In [31]:
import Pkg; Pkg.activate(".")

[32m[1m  Activating[22m[39m project at `~/SAFT_ML`


In [32]:
using Flux
using Flux: onecold, onehotbatch, logitcrossentropy
using Flux: DataLoader
using GraphNeuralNetworks
using MLDatasets
using MLUtils
using LinearAlgebra, Random, Statistics

ENV["DATADEPS_ALWAYS_ACCEPT"] = "true"  # don't ask for dataset download confirmation
Random.seed!(17) # for reproducibility

TaskLocalRNG()

In [33]:
dataset = TUDataset("MUTAG")

dataset TUDataset:
  name        =>    MUTAG
  metadata    =>    Dict{String, Any} with 1 entry
  graphs      =>    188-element Vector{MLDatasets.Graph}
  graph_data  =>    (targets = "188-element Vector{Int64}",)
  num_nodes   =>    3371
  num_edges   =>    7442
  num_graphs  =>    188

In [34]:
dataset.graphs[1]

Graph:
  num_nodes   =>    17
  num_edges   =>    38
  edge_index  =>    ("38-element Vector{Int64}", "38-element Vector{Int64}")
  node_data   =>    (targets = "17-element Vector{Int64}",)
  edge_data   =>    (targets = "38-element Vector{Int64}",)

In [35]:
dataset.graph_data.targets |> union

2-element Vector{Int64}:
  1
 -1

In [36]:
g1, y1 = dataset[1] #get the first graph and target

(graphs = Graph(17, 38), targets = 1)

In [37]:
reduce(vcat, g.node_data.targets for (g, _) in dataset) |> union

7-element Vector{Int64}:
 0
 1
 2
 3
 4
 5
 6

In [38]:
reduce(vcat, g.edge_data.targets for (g, _) in dataset) |> union


4-element Vector{Int64}:
 0
 1
 2
 3

In [39]:
graphs = mldataset2gnngraph(dataset)

188-element Vector{GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}}:
 GNNGraph(17, 38) with targets: 17-element, targets: 38-element data
 GNNGraph(13, 28) with targets: 13-element, targets: 28-element data
 GNNGraph(13, 28) with targets: 13-element, targets: 28-element data
 GNNGraph(19, 44) with targets: 19-element, targets: 44-element data
 GNNGraph(11, 22) with targets: 11-element, targets: 22-element data
 GNNGraph(28, 62) with targets: 28-element, targets: 62-element data
 GNNGraph(16, 34) with targets: 16-element, targets: 34-element data
 GNNGraph(20, 44) with targets: 20-element, targets: 44-element data
 GNNGraph(12, 26) with targets: 12-element, targets: 26-element data
 GNNGraph(17, 38) with targets: 17-element, targets: 38-element data
 ⋮
 GNNGraph(28, 66) with targets: 28-element, targets: 66-element data
 GNNGraph(11, 22) with targets: 11-element, targets: 22-element data
 GNNGraph(14, 30) with targets: 14-element, targets: 30-element data
 GNNGraph(22, 50) with t

In [9]:
graphs = mldataset2gnngraph(dataset)
graphs = [GNNGraph(g,
                    ndata = Float32.(onehotbatch(g.ndata.targets, 0:6)),
                    edata = nothing)
            for g in graphs]
y = onehotbatch(dataset.graph_data.targets, [-1, 1])

2×188 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
 ⋅  1  1  ⋅  1  ⋅  1  ⋅  1  ⋅  ⋅  ⋅  ⋅  …  ⋅  ⋅  1  ⋅  1  1  ⋅  ⋅  1  1  ⋅  1
 1  ⋅  ⋅  1  ⋅  1  ⋅  1  ⋅  1  1  1  1     1  1  ⋅  1  ⋅  ⋅  1  1  ⋅  ⋅  1  ⋅

In [54]:
g = graphs[1]
display(fieldnames(typeof(g)))

g.graph

(:graph, :num_nodes, :num_edges, :num_graphs, :graph_indicator, :ndata, :edata, :gdata)

([1, 1, 2, 2, 3, 3, 4, 4, 4, 5  …  13, 13, 13, 14, 14, 15, 15, 15, 16, 17], [2, 6, 1, 3, 2, 4, 3, 5, 10, 4  …  12, 14, 15, 13, 9, 13, 16, 17, 15, 15], nothing)

In [53]:
# Returns a OneHotMatrix where kth column of the matrix is onehot(xs[k], labels). This is a sparse matrix, which stores just a Vector{UInt32} containing the indices of the nonzero elements.
collect(onehotbatch(graphs[1].ndata.targets, 0:6))

7×17 Matrix{Bool}:
 1  1  1  1  1  1  1  1  1  1  1  1  1  1  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0  0  1  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  1  1
 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0

In [45]:
graphs[1].ndata.targets

17-element Vector{Int64}:
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 1
 2
 2

In [None]:
# 7x17, 7 features per node

In [26]:
graphs[1]

GNNGraph:
  num_nodes: 17
  num_edges: 38
  ndata:
	x = 7×17 Matrix{Float32}

In [10]:
train_data, test_data = splitobs((graphs, y), at = 150, shuffle = true) |> getobs

((GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}[GNNGraph(20, 44) with x: 7×20 data, GNNGraph(17, 38) with x: 7×17 data, GNNGraph(23, 54) with x: 7×23 data, GNNGraph(22, 50) with x: 7×22 data, GNNGraph(26, 60) with x: 7×26 data, GNNGraph(12, 26) with x: 7×12 data, GNNGraph(25, 56) with x: 7×25 data, GNNGraph(25, 58) with x: 7×25 data, GNNGraph(23, 54) with x: 7×23 data, GNNGraph(11, 22) with x: 7×11 data  …  GNNGraph(21, 44) with x: 7×21 data, GNNGraph(19, 44) with x: 7×19 data, GNNGraph(12, 26) with x: 7×12 data, GNNGraph(16, 34) with x: 7×16 data, GNNGraph(17, 36) with x: 7×17 data, GNNGraph(11, 22) with x: 7×11 data, GNNGraph(22, 50) with x: 7×22 data, GNNGraph(13, 28) with x: 7×13 data, GNNGraph(17, 38) with x: 7×17 data, GNNGraph(23, 48) with x: 7×23 data], Bool[0 0 … 0 0; 1 1 … 1 1]), (GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}[GNNGraph(16, 34) with x: 7×16 data, GNNGraph(20, 44) with x: 7×20 data, GNNGraph(12, 26) with x: 7×12 data, GNNGraph(23, 54) with x: 7

In [11]:
train_loader = DataLoader(train_data, batchsize = 32, shuffle = true)
test_loader = DataLoader(test_data, batchsize = 32, shuffle = false)

2-element DataLoader(::Tuple{Vector{GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}}, OneHotArrays.OneHotMatrix{UInt32, Vector{UInt32}}}, batchsize=32)
  with first element:
  (32-element Vector{GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}}, 2×32 OneHotMatrix(::Vector{UInt32}) with eltype Bool,)

In [12]:
vec_gs, _ = first(train_loader)

(GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}[GNNGraph(13, 28) with x: 7×13 data, GNNGraph(21, 44) with x: 7×21 data, GNNGraph(20, 46) with x: 7×20 data, GNNGraph(23, 54) with x: 7×23 data, GNNGraph(21, 44) with x: 7×21 data, GNNGraph(26, 60) with x: 7×26 data, GNNGraph(18, 40) with x: 7×18 data, GNNGraph(21, 44) with x: 7×21 data, GNNGraph(15, 32) with x: 7×15 data, GNNGraph(19, 44) with x: 7×19 data  …  GNNGraph(20, 44) with x: 7×20 data, GNNGraph(22, 50) with x: 7×22 data, GNNGraph(12, 26) with x: 7×12 data, GNNGraph(16, 34) with x: 7×16 data, GNNGraph(23, 54) with x: 7×23 data, GNNGraph(13, 26) with x: 7×13 data, GNNGraph(24, 50) with x: 7×24 data, GNNGraph(20, 46) with x: 7×20 data, GNNGraph(17, 36) with x: 7×17 data, GNNGraph(10, 20) with x: 7×10 data], Bool[1 0 … 0 1; 0 1 … 1 0])

In [13]:
MLUtils.batch(vec_gs)

GNNGraph:
  num_nodes: 561
  num_edges: 1228
  num_graphs: 32
  ndata:
	x = 7×561 Matrix{Float32}

In [20]:
function create_model(nin, nh, nout)
    GNNChain(GraphConv(nin => nh, relu),
             GraphConv(nh => nh, relu),
             GraphConv(nh => nh),
             GlobalPool(mean), # Average the node features
             Dropout(0.5),
             Dense(nh, nout))
end

function eval_loss_accuracy(model, data_loader, device)
    loss = 0.0
    acc = 0.0
    ntot = 0
    for (g, y) in data_loader
        g, y = MLUtils.batch(g) |> device, y |> device
        n = length(y)
        ŷ = model(g, g.ndata.x)
        loss += logitcrossentropy(ŷ, y) * n
        acc += mean((ŷ .> 0) .== y) * n
        ntot += n
    end
    return (loss = round(loss / ntot, digits = 4),
            acc = round(acc * 100 / ntot, digits = 2))
end

function train!(model; epochs = 200, η = 1e-2, infotime = 10)
    # device = Flux.gpu # uncomment this for GPU training
    device = Flux.cpu
    model = model |> device
    opt = Flux.setup(Adam(1e-3), model)

    function report(epoch)
        train = eval_loss_accuracy(model, train_loader, device)
        test = eval_loss_accuracy(model, test_loader, device)
        @info (; epoch, train, test)
    end

    report(0)
    for epoch in 1:epochs
        for (g, y) in train_loader
            g, y = MLUtils.batch(g) |> device, y |> device
            grad = Flux.gradient(model) do model
                ŷ = model(g, g.ndata.x)
                logitcrossentropy(ŷ, y)
            end
            Flux.update!(opt, model, grad[1])
        end
        epoch % infotime == 0 && report(epoch)
    end
end

train! (generic function with 1 method)

In [21]:
nin = 7
nh = 64
nout = 2
model = create_model(nin, nh, nout)
train!(model)

┌ Info: (epoch = 0, train = (loss = 0.5502, acc = 64.67), test = (loss = 0.5153, acc = 68.42))
└ @ Main /home/luc/SAFT_ML/3_gnn_example_2.ipynb:35
┌ Info: (epoch = 10, train = (loss = 0.4899, acc = 71.33), test = (loss = 0.4532, acc = 76.32))
└ @ Main /home/luc/SAFT_ML/3_gnn_example_2.ipynb:35


┌ Info: (epoch = 20, train = (loss = 0.4096, acc = 79.33), test = (loss = 0.4089, acc = 78.95))
└ @ Main /home/luc/SAFT_ML/3_gnn_example_2.ipynb:35
┌ Info: (epoch = 30, train = (loss = 0.3521, acc = 82.33), test = (loss = 0.3787, acc = 76.32))
└ @ Main /home/luc/SAFT_ML/3_gnn_example_2.ipynb:35


┌ Info: (epoch = 40, train = (loss = 0.3098, acc = 84.33), test = (loss = 0.3612, acc = 81.58))
└ @ Main /home/luc/SAFT_ML/3_gnn_example_2.ipynb:35
┌ Info: (epoch = 50, train = (loss = 0.281, acc = 85.67), test = (loss = 0.3501, acc = 80.26))
└ @ Main /home/luc/SAFT_ML/3_gnn_example_2.ipynb:35


┌ Info: (epoch = 60, train = (loss = 0.2825, acc = 86.33), test = (loss = 0.3688, acc = 77.63))
└ @ Main /home/luc/SAFT_ML/3_gnn_example_2.ipynb:35
┌ Info: (epoch = 70, train = (loss = 0.253, acc = 88.33), test = (loss = 0.3542, acc = 80.26))
└ @ Main /home/luc/SAFT_ML/3_gnn_example_2.ipynb:35


┌ Info: (epoch = 80, train = (loss = 0.2525, acc = 87.33), test = (loss = 0.3784, acc = 81.58))
└ @ Main /home/luc/SAFT_ML/3_gnn_example_2.ipynb:35
┌ Info: (epoch = 90, train = (loss = 0.2208, acc = 89.0), test = (loss = 0.3515, acc = 82.89))
└ @ Main /home/luc/SAFT_ML/3_gnn_example_2.ipynb:35


┌ Info: (epoch = 100, train = (loss = 0.2199, acc = 90.0), test = (loss = 0.3709, acc = 81.58))
└ @ Main /home/luc/SAFT_ML/3_gnn_example_2.ipynb:35
┌ Info: (epoch = 110, train = (loss = 0.2022, acc = 91.0), test = (loss = 0.3607, acc = 81.58))
└ @ Main /home/luc/SAFT_ML/3_gnn_example_2.ipynb:35


┌ Info: (epoch = 120, train = (loss = 0.2005, acc = 90.0), test = (loss = 0.3625, acc = 85.53))
└ @ Main /home/luc/SAFT_ML/3_gnn_example_2.ipynb:35
┌ Info: (epoch = 130, train = (loss = 0.2009, acc = 90.33), test = (loss = 0.3668, acc = 85.53))
└ @ Main /home/luc/SAFT_ML/3_gnn_example_2.ipynb:35


┌ Info: (epoch = 140, train = (loss = 0.1857, acc = 91.0), test = (loss = 0.375, acc = 86.84))
└ @ Main /home/luc/SAFT_ML/3_gnn_example_2.ipynb:35
┌ Info: (epoch = 150, train = (loss = 0.1834, acc = 91.67), test = (loss = 0.3843, acc = 86.84))
└ @ Main /home/luc/SAFT_ML/3_gnn_example_2.ipynb:35


┌ Info: (epoch = 160, train = (loss = 0.1801, acc = 91.33), test = (loss = 0.4097, acc = 82.89))
└ @ Main /home/luc/SAFT_ML/3_gnn_example_2.ipynb:35
┌ Info: (epoch = 170, train = (loss = 0.1769, acc = 91.67), test = (loss = 0.4267, acc = 82.89))
└ @ Main /home/luc/SAFT_ML/3_gnn_example_2.ipynb:35


┌ Info: (epoch = 180, train = (loss = 0.168, acc = 92.0), test = (loss = 0.4087, acc = 86.84))
└ @ Main /home/luc/SAFT_ML/3_gnn_example_2.ipynb:35
┌ Info: (epoch = 190, train = (loss = 0.1607, acc = 91.0), test = (loss = 0.4164, acc = 86.84))
└ @ Main /home/luc/SAFT_ML/3_gnn_example_2.ipynb:35


┌ Info: (epoch = 200, train = (loss = 0.1642, acc = 92.0), test = (loss = 0.404, acc = 86.84))
└ @ Main /home/luc/SAFT_ML/3_gnn_example_2.ipynb:35
