In [1]:
using GeometricFlux;
using Flux;
using Flux: onecold, crossentropy, throttle, @epochs;
using JLD2;  # use v0.1.2
using SparseArrays;
using Statistics: mean;
using LightGraphs: SimpleGraphs, adjacency_matrix;

Loading the data: the Cora dataset

In [2]:
@load "data/cora_features.jld2" features;
@load "data/cora_labels.jld2" labels;
@load "data/cora_graph.jld2" g;

train_X = Float32.(features);  # dim: num_features * num_nodes
train_y = Float32.(labels);  # dim: target_catg * num_nodes

adj_mat = Matrix{Float32}(adjacency_matrix(g));
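The cell above uses a column-per-node layout. `train_X` has one column of features per node. `train_y` appears to hold one one-hot column per node, judging by the `target_catg * num_nodes` comment and the later use of `onecold`. `adj_mat` is a node × node matrix, which is symmetric if `g` is an undirected graph (an assumption here). The toy version below rebuilds the same layout in plain Julia. The edge list, feature count, and labels are hypothetical, chosen only for illustration.

```julia
using SparseArrays

# Hypothetical toy graph: 4 nodes, undirected edges 1-2, 2-3, 3-4
n = 4
edges = [(1, 2), (2, 3), (3, 4)]
src = first.(edges)
dst = last.(edges)

# Store each undirected edge in both directions so the matrix is symmetric
A = sparse(vcat(src, dst), vcat(dst, src), ones(Float32, 2 * length(edges)), n, n)
adj = Matrix{Float32}(A)     # dense copy, as the notebook does with adj_mat

# Features: one column per node (5 hypothetical features)
X = rand(Float32, 5, n)

# Labels: hypothetical class indices turned into a 3 × n one-hot matrix
labels = [1, 3, 2, 3]
Y = Float32.((1:3) .== labels')
```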

Model parameters:
- width of the hidden layer
- number of output classes
- number of training epochs

In [3]:
hidden_layer_width = 16;
num_classes = 7;
epochs = 20;
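To get a rough sense of model size, the parameter count can be worked out by hand. This rests on two assumptions. First, each GCN layer holds a dense weight matrix plus a bias vector. Second, the data is the standard Cora release with 1433 features per node (in the notebook, `size(train_X, 1)`). The helper `gcn_param_count` is hypothetical.

```julia
# Hypothetical helper: parameters of one GCN layer, assuming a dense weight
# matrix (out × in) plus an optional bias vector of length `out`.
gcn_param_count(in_dim, out_dim; bias=true) = in_dim * out_dim + (bias ? out_dim : 0)

num_features = 1433       # assumption: standard Cora feature dimension
hidden_layer_width = 16
num_classes = 7

first_layer = gcn_param_count(num_features, hidden_layer_width)   # 22944
second_layer = gcn_param_count(hidden_layer_width, num_classes)   # 119
total = first_layer + second_layer                                # 23063
println("total trainable parameters: ", total)
```

Dropout and softmax add no trainable parameters, so nearly all of the model's weights sit in the first layer.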

Model definition using the `GeometricFlux.jl` package:
- a GCN layer of width `hidden_layer_width` with ReLU activation
- dropout (p = 0.5)
- a second GCN layer of width `num_classes` with linear activation
- a softmax function

In [4]:
model = Chain(
    GCNConv(adj_mat, size(train_X, 1) => hidden_layer_width, relu),
    Dropout(0.5),
    GCNConv(adj_mat, hidden_layer_width => num_classes),
    softmax
);
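The following plain-Julia sketch shows what each `GCNConv` layer computes. It uses the usual GCN propagation rule of Kipf and Welling: add self-loops, normalize the adjacency matrix symmetrically by node degree, and then mix each node's features with those of its neighbours. Whether `GCNConv` in the GeometricFlux version used here normalizes exactly this way is an assumption. The functions `normalized_adjacency`, `gcn_layer`, and `softmax_columns`, and the toy graph and weights, are hypothetical, and dropout is left out.

```julia
using LinearAlgebra

# Symmetric normalization with self-loops:
#   A_hat = D^(-1/2) (A + I) D^(-1/2), where D is the degree matrix of A + I
function normalized_adjacency(A::AbstractMatrix{T}) where {T<:AbstractFloat}
    A_tilde = A + I                      # add a self-loop to every node
    d = vec(sum(A_tilde, dims=2))        # node degrees including the self-loop
    D = Diagonal(one(T) ./ sqrt.(d))
    return D * A_tilde * D
end

# One GCN layer in the notebook's layout (features × nodes):
# column j of H * A_hat is a degree-weighted mix of node j and its neighbours.
gcn_layer(A_hat, H, W, b, σ) = σ.(W * H * A_hat .+ b)

relu(x) = max(zero(x), x)

# Column-wise softmax, shifted by the column maximum for numerical stability
function softmax_columns(Z)
    E = exp.(Z .- maximum(Z, dims=1))
    return E ./ sum(E, dims=1)
end

# Toy forward pass: 4-node path graph, 5 features, 16 hidden units, 3 classes
A = Float32[0 1 0 0; 1 0 1 0; 0 1 0 1; 0 0 1 0]
X = rand(Float32, 5, 4)
A_hat = normalized_adjacency(A)
W1, b1 = randn(Float32, 16, 5), zeros(Float32, 16)
W2, b2 = randn(Float32, 3, 16), zeros(Float32, 3)
H = gcn_layer(A_hat, X, W1, b1, relu)          # dropout is omitted in this sketch
Y_hat = softmax_columns(gcn_layer(A_hat, H, W2, b2, identity))
```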

Loss function: cross-entropy. As a running metric, we report accuracy on the training data.

In [5]:
loss(x, y) = crossentropy(model(x), y);
accuracy(x, y) = mean(onecold(model(x)) .== onecold(y));
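`crossentropy` and `onecold` come from Flux. Their behaviour on a column-per-node matrix can be reproduced in a few lines of plain Julia. This sketch averages the cross-entropy over columns (nodes), which is assumed to match Flux's default aggregation. The names `crossentropy_sketch`, `onecold_sketch`, and `accuracy_sketch` are hypothetical.

```julia
using Statistics: mean

# Cross-entropy averaged over columns (one column per node)
crossentropy_sketch(y_hat, y) = -sum(y .* log.(y_hat)) / size(y, 2)

# onecold: row index of the largest entry in each column
onecold_sketch(M) = [argmax(M[:, j]) for j in 1:size(M, 2)]

# Fraction of nodes whose predicted class matches the true class
accuracy_sketch(y_hat, y) = mean(onecold_sketch(y_hat) .== onecold_sketch(y))

# Three nodes, three classes: predicted probabilities and one-hot targets
y_hat = Float32[0.7 0.2 0.1;
                0.2 0.5 0.3;
                0.1 0.3 0.6]
y = Float32[1 0 1;
            0 1 0;
            0 0 0]

println(accuracy_sketch(y_hat, y))   # nodes 1 and 2 are correct, node 3 is not
println(crossentropy_sketch(y_hat, y))
```

Because the model ends with `softmax`, each output column is already a probability distribution, which is the input `crossentropy` expects.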

We train with the ADAM optimizer and a learning rate of `η = 0.05`.

In [6]:
train_data = [(train_X, train_y)];
opt = ADAM(0.05);
evalcb() = @show(accuracy(train_X, train_y));

@epochs epochs Flux.train!(loss, Flux.params(model), train_data, opt, cb=throttle(evalcb, 10));
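ADAM keeps running averages of the gradient and of its square, and it scales each step by their ratio. The sketch below shows one textbook Adam step for a single parameter vector. It uses η = 0.05 as in the notebook and assumes the usual defaults β₁ = 0.9, β₂ = 0.999, ϵ = 1e-8. Flux's `ADAM` may differ in small details, such as where ϵ enters. The names `AdamState` and `adam_step!` are hypothetical.

```julia
# Hypothetical state for one parameter vector: first and second moment estimates
mutable struct AdamState
    m::Vector{Float64}
    v::Vector{Float64}
    t::Int
end
AdamState(n::Integer) = AdamState(zeros(n), zeros(n), 0)

# One Adam step with bias correction; updates θ in place
function adam_step!(θ, g, s::AdamState; η=0.05, β1=0.9, β2=0.999, ϵ=1e-8)
    s.t += 1
    s.m .= β1 .* s.m .+ (1 - β1) .* g          # running mean of gradients
    s.v .= β2 .* s.v .+ (1 - β2) .* g .^ 2     # running mean of squared gradients
    m_hat = s.m ./ (1 - β1^s.t)                 # bias-corrected estimates
    v_hat = s.v ./ (1 - β2^s.t)
    θ .-= η .* m_hat ./ (sqrt.(v_hat) .+ ϵ)
    return θ
end

# Usage: minimise f(θ) = (θ - 3)^2, whose gradient is 2(θ - 3)
theta = [0.0]
state = AdamState(1)
for _ in 1:300
    adam_step!(theta, 2 .* (theta .- 3), state)
end
println(theta)
```

On the first step the bias-corrected ratio is close to ±1, so each parameter moves by roughly η regardless of the gradient's magnitude. This is one reason the relatively large η = 0.05 still gives controlled updates.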

Training accuracy printed by `evalcb` after each epoch:

| Epoch | `accuracy(train_X, train_y)` |
|---|---|
| 1 | 0.1532496307237814 |
| 2 | 0.21344165435745938 |
| 3 | 0.27141802067946824 |
| 4 | 0.32533234859675036 |
| 5 | 0.36152141802067944 |
| 6 | 0.39032496307237813 |
| 7 | 0.4169128508124077 |
| 8 | 0.43353028064992616 |
| 9 | 0.4519940915805022 |
| 10 | 0.46454948301329396 |
| 11 | 0.47525849335302806 |
| 12 | 0.48338257016248154 |
| 13 | 0.4885524372230428 |
| 14 | 0.49556868537666177 |
| 15 | 0.49963072378138845 |
| 16 | 0.5051698670605613 |
| 17 | 0.5107090103397341 |
| 18 | 0.5158788774002954 |
| 19 | 0.5210487444608567 |
| 20 | 0.5273264401772526 |


The code is a modification of the examples from the [GeometricFlux.jl](https://github.com/yuehhua/GeometricFlux.jl) package.