In [None]:
using Pkg
# Install each third-party dependency of this notebook that `Pkg.installed()` does not list.
# The installed set is fetched once; adding one of these packages does not make another
# one a direct dependency, so the lookup does not need refreshing inside the loop.
let installed = keys(Pkg.installed())
    for pkg in ("Knet", "Plots", "IterTools", "PyCall")
        pkg in installed || Pkg.add(pkg)
    end
end
using DelimitedFiles
using Knet: KnetArray, accuracy, progress, minibatch, cycle, adam, xavier_uniform, progress!, @save, @load
using Plots
using IterTools: ncycle, takenth, take
using Base.Iterators: flatten
using LinearAlgebra

include("utils.jl")
include("models.jl")

In [None]:
"""
    mytrain!(model, data, epochs, lr, window_size)

Train `model` with Adam on `data` for at most `epochs` passes (`ncycle(data, epochs)`),
stopping early once the validation loss has failed to decrease for `window_size`
consecutive optimizer steps.

After every optimizer step the training loss `model(data)` and the validation loss
`val_loss(model, data)` are recorded. `data` is iterated as `(x, y)` pairs by `val_loss`.

# Returns
`(1:iter, trnloss, valloss)`, where `iter` is the number of optimizer steps taken and
`trnloss` / `valloss` hold one loss value per step.
"""
function mytrain!(model, data, epochs, lr, window_size)
    trnloss = Float64[]
    valloss = Float64[]
    # Start at Inf so the first step always counts as an improvement; starting at 0 made
    # every (non-negative) first loss look like a regression.
    prev_val_loss = Inf
    early_stop_counter = 0
    iter = 0

    # `adam` returns a lazy iterator: every element pulled from it performs one update,
    # so stopping early means not pulling any more elements (hence `break`, not a filter).
    training = adam(model, ncycle(data, epochs), lr=lr)
    for _ in progress(training)
        push!(trnloss, model(data))
        v_loss = val_loss(model, data)
        push!(valloss, v_loss)
        iter += 1

        # Count consecutive steps without a decrease; any improvement resets the count.
        early_stop_counter = v_loss >= prev_val_loss ? early_stop_counter + 1 : 0
        prev_val_loss = v_loss
        early_stop_counter >= window_size && break
    end
    return 1:iter, trnloss, valloss
end

In [None]:
# TODO: take user inputs
"""
    args(epochs, lr, weight_decay, hidden, pdrop, window_size)

Hyper-parameters shared by every experiment in this notebook.

- `epochs`: number of passes over the data handed to `adam` in `mytrain!`.
- `lr`: Adam learning rate.
- `weight_decay`: coefficient of the squared-weight penalty on `layer1.w` in every loss.
- `hidden`: second argument of the `GCN` constructor (presumably the hidden width).
- `pdrop`: last argument of the `GCN` constructor (presumably a dropout probability).
- `window_size`: early-stopping patience, in optimizer steps, used by `mytrain!`.

Fields are concretely typed so reading them (e.g. inside the loss functions) is
type-stable; arguments are converted on construction.
"""
struct args
    epochs::Int
    lr::Float64
    weight_decay::Float64
    hidden::Int
    pdrop::Float64
    window_size::Int
end

# `const` so the loss functions that read `arguments.weight_decay` see a known type.
const arguments = args(200, 0.01, 5e-4, 16, 0.5, 10)
# Squared-L2 penalty on the first layer's weights, scaled by `arguments.weight_decay`.
# The training, validation and test losses all add this same term.
_l2_penalty(g::GCN) = arguments.weight_decay * sum(g.layer1.w .* g.layer1.w)

# `nll` of the model output restricted to the columns `idx` (and labels `y[idx]`),
# plus `_l2_penalty(g)`. Callers pass one of the global index sets, read at call time,
# so the most recently loaded dataset's split is always used.
function _masked_loss(g::GCN, x, y, idx)
    output = g(x)[:, idx]
    return nll(output, y[idx]) + _l2_penalty(g)
end

"""
    val_loss(g::GCN, x, y)
    val_loss(g::GCN, d)

Validation loss: masked `nll` over `idx_val` plus the weight-decay penalty. The second
method averages it over the `(x, y)` pairs of `d`.
"""
val_loss(g::GCN, x, y) = _masked_loss(g, x, y, idx_val)
val_loss(g::GCN, d) = mean(val_loss(g, x, y) for (x, y) in d)

"""
    test_loss(g::GCN, x, y)
    test_loss(g::GCN, d)

Test loss: masked `nll` over `idx_test` plus the weight-decay penalty. The second method
averages it over the `(x, y)` pairs of `d`.
"""
test_loss(g::GCN, x, y) = _masked_loss(g, x, y, idx_test)
test_loss(g::GCN, d) = mean(test_loss(g, x, y) for (x, y) in d)

# Training loss over `idx_train`; this is the objective used when the model itself is
# passed to `adam` in `mytrain!`.
(g::GCN)(x, y) = _masked_loss(g, x, y, idx_train)

In [None]:
#####################################################################################################################

In [None]:
# Load the "cora" dataset with `load_dataset` (defined in one of the included files).
(adj, features, labels, idx_train, idx_val, idx_test) = load_dataset("cora")

In [None]:
# Fresh GCN sized from the data (assumed argument order: input size, hidden size,
# output size, adjacency, dropout — TODO confirm against the GCN constructor).
model = GCN(size(features, 1), arguments.hidden, size(labels, 2), adj, arguments.pdrop)

In [None]:
# Class index per row of `labels`: the position of each row's maximum, as a Vector.
labels_decoded = vec(mapslices(argmax, labels; dims=2))

In [None]:
# A single minibatch whose batch size equals the number of labels.
data = minibatch(features, copy(labels_decoded), length(labels_decoded))

In [None]:
# Adam training with early stopping; keeps the step range and per-step losses.
(iters, trnloss, vallos) = mytrain!(model, data, arguments.epochs,
                                    arguments.lr, arguments.window_size)

In [None]:
# Per-step training and validation loss returned by `mytrain!`. No x-limits, so the whole
# (possibly early-stopped) run is shown; one label per plotted series.
plot(iters, [trnloss, vallos], labels=[:trn :val], xlabel="epochs", ylabel="loss")

In [None]:
# Model output for all samples; report accuracy on the training split.
output = model(features)
accuracy(output[:, idx_train], labels_decoded[idx_train])

In [None]:
# Accuracy on the test split, reusing `output` from the previous cell.
accuracy(output[:, idx_test], labels_decoded[idx_test])

In [None]:
# Save the current plot as a PNG image.
savefig("cora.png")

In [None]:
#####################################################################################################################

In [None]:
# Load the "citeseer" dataset with `load_dataset` (defined in one of the included files).
(adj, features, labels, idx_train, idx_val, idx_test) = load_dataset("citeseer")

In [None]:
# Fresh GCN sized from the data (assumed argument order: input size, hidden size,
# output size, adjacency, dropout — TODO confirm against the GCN constructor).
model = GCN(size(features, 1), arguments.hidden, size(labels, 2), adj, arguments.pdrop)

In [None]:
# Class index per row of `labels`: the position of each row's maximum, as a Vector.
labels_decoded = vec(mapslices(argmax, labels; dims=2))

In [None]:
# A single minibatch whose batch size equals the number of labels.
data = minibatch(features, copy(labels_decoded), length(labels_decoded))

In [None]:
# Adam training with early stopping; keeps the step range and per-step losses.
(iters, trnloss, vallos) = mytrain!(model, data, arguments.epochs,
                                    arguments.lr, arguments.window_size)

In [None]:
# Per-step training and validation loss returned by `mytrain!`. No x-limits, so the whole
# (possibly early-stopped) run is shown; one label per plotted series.
plot(iters, [trnloss, vallos], labels=[:trn :val], xlabel="epochs", ylabel="loss")

In [None]:
# Model output for all samples; report accuracy on the training split.
output = model(features)
accuracy(output[:, idx_train], labels_decoded[idx_train])

In [None]:
# Accuracy on the test split, reusing `output` from the previous cell.
accuracy(output[:, idx_test], labels_decoded[idx_test])

In [None]:
# Save the current plot as a PNG image.
savefig("citeseer.png")

In [None]:
#####################################################################################################################

In [None]:
# Load the "pubmed" dataset with `load_dataset` (defined in one of the included files).
(adj, features, labels, idx_train, idx_val, idx_test) = load_dataset("pubmed")

In [None]:
# Fresh GCN sized from the data (assumed argument order: input size, hidden size,
# output size, adjacency, dropout — TODO confirm against the GCN constructor).
model = GCN(size(features, 1), arguments.hidden, size(labels, 2), adj, arguments.pdrop)

In [None]:
# Class index per row of `labels`: the position of each row's maximum, as a Vector.
labels_decoded = vec(mapslices(argmax, labels; dims=2))

In [None]:
# A single minibatch whose batch size equals the number of labels.
data = minibatch(features, copy(labels_decoded), length(labels_decoded))

In [None]:
# Adam training with early stopping; keeps the step range and per-step losses.
(iters, trnloss, vallos) = mytrain!(model, data, arguments.epochs,
                                    arguments.lr, arguments.window_size)

In [None]:
# Per-step training and validation loss returned by `mytrain!`. No x-limits, so the whole
# (possibly early-stopped) run is shown; one label per plotted series.
plot(iters, [trnloss, vallos], labels=[:trn :val], xlabel="epochs", ylabel="loss")

In [None]:
# Model output for all samples; report accuracy on the training split.
output = model(features)
accuracy(output[:, idx_train], labels_decoded[idx_train])

In [None]:
# Accuracy on the test split, reusing `output` from the previous cell.
accuracy(output[:, idx_test], labels_decoded[idx_test])

In [None]:
# Save the current plot as a PNG image.
savefig("pubmed.png")