In [1]:
using GeometricFlux;
using Flux;
using Flux: onecold, crossentropy, throttle, @epochs;
using JLD2;  # use v0.1.2
using SparseArrays;
using Statistics: mean;
using LightGraphs: SimpleGraphs, adjacency_matrix;

Loading the data: the Cora dataset

In [2]:
@load "data/cora_features.jld2" features;
@load "data/cora_labels.jld2" labels;
@load "data/cora_graph.jld2" g;

## Preprocessing data
train_X = Float32.(features);  # dim: num_features * num_nodes
train_y = Float32.(labels);  # dim: target_catg * num_nodes

adj_mat = Matrix{Float32}(adjacency_matrix(g));  # dense num_nodes × num_nodes adjacency matrix
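For an undirected graph, `adjacency_matrix(g)` returns a symmetric sparse 0/1 matrix, which the cell above converts to a dense `Float32` matrix. The sketch below builds the same kind of matrix from an edge list using only the `SparseArrays` standard library. The helper `dense_adjacency` and the toy path graph are hypothetical and are not part of the notebook or of LightGraphs:

```julia
using SparseArrays

# Hypothetical helper: dense Float32 adjacency matrix of an undirected graph
# with `n` nodes, given as a list of (u, v) edges
function dense_adjacency(edges::Vector{Tuple{Int,Int}}, n::Int)
    src = Int[]
    dst = Int[]
    for (u, v) in edges
        # an undirected edge is stored in both directions, so A is symmetric
        push!(src, u); push!(dst, v)
        push!(src, v); push!(dst, u)
    end
    # `max` as the combine function keeps repeated entries at 1 instead of summing them
    A = sparse(src, dst, ones(Float32, length(src)), n, n, max)
    return Matrix{Float32}(A)
end

# Toy path graph 1 - 2 - 3 - 4 (not Cora)
A = dense_adjacency([(1, 2), (2, 3), (3, 4)], 4)
```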

Model hyperparameters:
- width of the hidden layer
- number of output classes
- number of training epochs
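These settings determine the size of the network. As a worked example, the trainable parameters of the two GCN layers can be counted as below. This assumes that each GCN layer stores an out × in weight matrix plus a bias vector, and it uses 1433 as the input width, the length of Cora's bag-of-words feature vectors (in the notebook this value comes from `size(train_X, 1)`). Dropout and softmax add no parameters.

```julia
# Parameters of one GCN layer, assuming an (out × in) weight matrix plus a bias of length out
gcn_layer_params(in_dim, out_dim) = in_dim * out_dim + out_dim

# Two stacked GCN layers: input → hidden → classes (dropout and softmax have no parameters)
two_layer_gcn_params(num_features, hidden, num_classes) =
    gcn_layer_params(num_features, hidden) + gcn_layer_params(hidden, num_classes)

# 1433 input features is an assumption about the Cora data files
n_params = two_layer_gcn_params(1433, 16, 7)
println(n_params)  # 23063
```

Almost all of these parameters (22944 of 23063) sit in the first layer, so the hidden width is the main factor that controls model size.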

In [3]:
hidden_layer_width = 16;
num_classes = 7;
epochs = 20;

Model definition using the building blocks of the `GeometricFlux.jl` package:
- a GCN layer of width `hidden_layer_width` with ReLU activation
- dropout
- a second GCN layer of width `num_classes` with linear (identity) activation
- a softmax function
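Each GCN layer applies the propagation rule of Kipf and Welling. Self-loops are added (`A + I`), and the result is normalised symmetrically: `Â = D̃^(-1/2) (A + I) D̃^(-1/2)`, where `D̃` is the diagonal degree matrix of `A + I`. A layer then maps node features `H` to `σ(Â H W)`. Because this notebook stores features as columns (`num_features × num_nodes`), the same rule reads `σ.(W * X * Â .+ b)` with an optional bias `b`. The sketch below implements that rule with the `LinearAlgebra` standard library. It is not GeometricFlux's own code, whose internals (for instance, where the normalisation is computed) may differ between versions, and the toy weights are arbitrary:

```julia
using LinearAlgebra

# Symmetric GCN normalisation: Â = D̃^(-1/2) (A + I) D̃^(-1/2)
function gcn_normalize(A::AbstractMatrix{T}) where {T<:AbstractFloat}
    Ã = A + I                          # add a self-loop to every node
    d = vec(sum(Ã; dims=2))            # node degrees, self-loop included
    D = Diagonal(one(T) ./ sqrt.(d))   # D̃^(-1/2)
    return D * Ã * D
end

# One GCN layer in the notebook's layout (features × nodes):
# X' = σ.(W * X * Â .+ b), with W of size out × in and b of length out
gcn_layer(Â, W, b, X, σ=identity) = σ.(W * X * Â .+ b)

# Toy example (not Cora): 3-node path graph, 2 input features, 4 hidden units
A = Float32[0 1 0; 1 0 1; 0 1 0]
Â = gcn_normalize(A)
X = Float32[1 0 0; 0 1 0]              # column j holds the features of node j
W = ones(Float32, 4, 2)                # arbitrary weights, 2 → 4
b = zeros(Float32, 4)
H = gcn_layer(Â, W, b, X, x -> max(x, zero(x)))   # ReLU, as in the first layer
```

Each output column mixes a node's own features with those of its neighbours, weighted by `Â`. Node 3 has no features of its own here, yet it receives a non-zero representation from node 2.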

In [4]:
## Model
model = Chain(
    GCNConv(adj_mat, size(train_X, 1) => hidden_layer_width, relu),
    Dropout(0.5),
    GCNConv(adj_mat, hidden_layer_width => num_classes),
    softmax
);
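The two parameter-free stages of the chain can be sketched directly. To our understanding, Flux's `Dropout(0.5)` uses inverted dropout. While training, it zeroes each activation with probability 0.5 and scales the survivors by 1/(1 - 0.5) = 2, which leaves expected activations unchanged. At evaluation time it passes its input through unchanged. `softmax` normalises each column, that is, each node, into a probability vector over the classes. The helper names below are hypothetical stand-ins, not Flux functions:

```julia
using Random

# Inverted dropout: zero each entry with probability p and rescale the rest by 1/(1 - p);
# when not training, return the input unchanged
function dropout_sketch(x::AbstractArray{T}, p; training::Bool=true,
                        rng::AbstractRNG=Random.default_rng()) where {T<:AbstractFloat}
    training || return x
    keep = rand(rng, T, size(x)) .>= p        # true with probability 1 - p
    return x .* keep ./ (one(T) - T(p))
end

# Column-wise softmax: every column (one node) becomes a probability vector
function softmax_columns(z::AbstractMatrix)
    e = exp.(z .- maximum(z; dims=1))        # subtract the column maximum for numerical stability
    return e ./ sum(e; dims=1)
end

rng = MersenneTwister(42)
H = ones(Float32, 16, 5)                        # stand-in for hidden activations
H_drop = dropout_sketch(H, 0.5; rng=rng)        # every entry is either 0 or 2
P = softmax_columns(randn(rng, Float32, 7, 5))  # 7 classes × 5 nodes
```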

The loss function is cross-entropy. As a running metric, we report accuracy on the training data.

In [5]:
## Loss
loss(x, y) = crossentropy(model(x), y);
accuracy(x, y) = mean(onecold(model(x)) .== onecold(y));
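In plain terms, `loss` averages `-Σ_c y_c log ŷ_c` over the columns (nodes). `accuracy` compares the predicted class, which `onecold` returns as the index of the largest output in each column, with the true class. Below is a standard-library sketch of both on a hand-made example. The `_sketch` functions are hypothetical re-implementations. The small `ϵ` inside the logarithm is an added safeguard, and Flux's own `crossentropy` may handle that case differently depending on the version:

```julia
using Statistics: mean

# Mean categorical cross-entropy over columns (nodes); ŷ holds predicted probabilities,
# y the one-hot targets, both of size classes × nodes. ϵ avoids log(0).
crossentropy_sketch(ŷ, y; ϵ=eps(eltype(ŷ))) = mean(-sum(y .* log.(ŷ .+ ϵ); dims=1))

# Index of the largest entry in each column, i.e. the predicted (or true) class
onecold_sketch(m::AbstractMatrix) = [argmax(view(m, :, j)) for j in axes(m, 2)]

# Fraction of nodes whose predicted class matches the target class
accuracy_sketch(ŷ, y) = mean(onecold_sketch(ŷ) .== onecold_sketch(y))

ŷ = Float32[0.7 0.2; 0.2 0.5; 0.1 0.3]   # 3 classes × 2 nodes, columns sum to 1
y = Float32[1 0; 0 0; 0 1]               # node 1 is class 1, node 2 is class 3
ce = crossentropy_sketch(ŷ, y)           # (-log 0.7 - log 0.3) / 2 ≈ 0.78
acc = accuracy_sketch(ŷ, y)              # 0.5: node 2 is predicted as class 2
```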

We train with the ADAM optimizer using a learning rate of `η = 0.05`.
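ADAM keeps running averages of each parameter's gradient and squared gradient and scales every step by their bias-corrected ratio. As a result, early updates have a magnitude close to the learning rate `η`, regardless of the gradient's scale. The following standard-library sketch of the standard update rule, with the usual defaults `β1 = 0.9`, `β2 = 0.999` and `ϵ = 1e-8`, minimises a toy quadratic. `AdamState` and `adam_step!` are hypothetical names, and Flux's `ADAM` may differ in small details such as where `ϵ` enters:

```julia
# Running ADAM state for one parameter array
mutable struct AdamState{T}
    m::Array{T}   # exponential moving average of gradients (first moment)
    v::Array{T}   # exponential moving average of squared gradients (second moment)
    t::Int        # number of steps taken, used for bias correction
end
AdamState(p::Array{T}) where {T} = AdamState(zero(p), zero(p), 0)

# One in-place ADAM update of parameters p given gradient g
function adam_step!(p, g, s::AdamState; η=0.05, β1=0.9, β2=0.999, ϵ=1e-8)
    s.t += 1
    s.m .= β1 .* s.m .+ (1 - β1) .* g
    s.v .= β2 .* s.v .+ (1 - β2) .* g .^ 2
    m̂ = s.m ./ (1 - β1^s.t)            # bias-corrected first moment
    v̂ = s.v ./ (1 - β2^s.t)            # bias-corrected second moment
    p .-= η .* m̂ ./ (sqrt.(v̂) .+ ϵ)
    return p
end

# Toy problem: minimise sum((p .- 3).^2), whose gradient is 2 .* (p .- 3)
p = [0.0, 10.0]
s = AdamState(p)
adam_step!(p, 2 .* (p .- 3.0), s)
first_step = copy(p)                   # ≈ [0.05, 9.95]
for _ in 2:500
    adam_step!(p, 2 .* (p .- 3.0), s)
end
println(p)                             # close to [3.0, 3.0]
```

The first step moves each coordinate by almost exactly `η = 0.05`, even though the two initial gradients (-6 and 14) differ considerably in size.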

In [6]:
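## Training
train_data = [(train_X, train_y)];  # one batch containing the whole graph
opt = ADAM(0.05);
evalcb() = @show(accuracy(train_X, train_y));  # prints the training accuracy

# repeat Flux.train! for `epochs` epochs; the callback runs at most once every 10 s
@epochs epochs Flux.train!(loss, Flux.params(model), train_data, opt, cb=throttle(evalcb, 10));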

Epoch  1: accuracy(train_X, train_y) = 0.08825701624815362
Epoch  2: accuracy(train_X, train_y) = 0.09933530280649926
Epoch  3: accuracy(train_X, train_y) = 0.15398818316100443
Epoch  4: accuracy(train_X, train_y) = 0.20901033973412111
Epoch  5: accuracy(train_X, train_y) = 0.26292466765140327
Epoch  6: accuracy(train_X, train_y) = 0.3076070901033973
Epoch  7: accuracy(train_X, train_y) = 0.34268833087149186
Epoch  8: accuracy(train_X, train_y) = 0.3744460856720827
Epoch  9: accuracy(train_X, train_y) = 0.4069423929098966
Epoch 10: accuracy(train_X, train_y) = 0.4287296898079764
Epoch 11: accuracy(train_X, train_y) = 0.45605612998522893
Epoch 12: accuracy(train_X, train_y) = 0.48005908419497784
Epoch 13: accuracy(train_X, train_y) = 0.4977843426883309
Epoch 14: accuracy(train_X, train_y) = 0.5158788774002954
Epoch 15: accuracy(train_X, train_y) = 0.5243722304283605
Epoch 16: accuracy(train_X, train_y) = 0.5324963072378139
Epoch 17: accuracy(train_X, train_y) = 0.5350812407680945
Epoch 18: accuracy(train_X, train_y) = 0.5409896602658789
Epoch 19: accuracy(train_X, train_y) = 0.5428360413589365
Epoch 20: accuracy(train_X, train_y) = 0.5443131462333826
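Because `train_data` holds a single full-graph batch, each call of `Flux.train!` performs one gradient step and then invokes its callback, which `throttle(evalcb, 10)` limits to at most one call every 10 seconds. `@epochs` re-evaluates the whole `Flux.train!(...)` expression in each epoch, so a fresh throttled callback is presumably created every time. That would explain why the log above reports the accuracy once per epoch. The idea behind `throttle` can be sketched as follows. This is a hypothetical re-implementation that roughly matches Flux's default `leading=true, trailing=false` behaviour, not Flux's code:

```julia
# Wrap f so that it runs at most once every `timeout` seconds;
# calls arriving sooner are skipped and return `nothing`
function throttle_sketch(f, timeout)
    last_call = -Inf                  # time of the last call that actually ran
    return function (args...)
        t_now = time()
        if t_now - last_call >= timeout
            last_call = t_now
            return f(args...)
        end
        return nothing
    end
end

calls = Ref(0)
cb = throttle_sketch(() -> (calls[] += 1), 10)
for _ in 1:5
    cb()                              # five calls in quick succession
end
println(calls[])                      # 1: only the first call ran

fresh = throttle_sketch(() -> (calls[] += 1), 10)   # a new wrapper starts with no history
fresh()
println(calls[])                      # 2
```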


The code is a modified version of the examples from the [GeometricFlux.jl](https://github.com/yuehhua/GeometricFlux.jl) package.