In [1]:
using Pkg
for p in ("Knet","Images","ImageMagick", "MAT", "LinearAlgebra")
    haskey(Pkg.installed(),p) || Pkg.add(p)
end
using Knet, MAT, Images, Random
using Base.Iterators: flatten, cycle, take
using IterTools
using LinearAlgebra
using Statistics: mean
using Plots; default(fmt=:png,ls=:auto)
include(Knet.dir("data","mnist.jl"))  # Load data
import Base: length, size, iterate

In [2]:
using Knet: Data

In [3]:
dtrn,dtst = mnistdata(;batchsize = 25);  

┌ Info: Loading MNIST...
└ @ Main /home/cankucuksozen/.julia/packages/Knet/vxHRi/data/mnist.jl:33


In [4]:
(x,y) = first(dtst)
println.(summary.((x,y)));
println(summary(dtrn))

28×28×1×25 KnetArray{Float32,4}
25-element Array{UInt8,1}
2400-element Data{Tuple{KnetArray{Float32,4},Array{UInt8,1}}}


In [5]:
struct Conv
    w
    stride
    padding
end

function Conv(w1::Int, w2::Int, cx::Int, cy::Int; stride = 1, padding = 0)
    w = param(w1, w2, cx, cy)
    return Conv(w, stride, padding)
end

function (c::Conv)(x)
    return conv4(c.w, x ; padding = c.padding, stride = c.stride)
end

In [6]:
struct Deconv
    w
    stride
    padding
end

function Deconv(w1::Int, w2::Int, cy::Int, cx::Int; stride = 1, padding = 0)
    w = param(w1, w2, cy, cx)
    return Deconv(w, stride, padding)
end

function (c::Deconv)(x)
    return deconv4(c.w, x ; padding = c.padding, stride = c.stride)
end

In [7]:
struct Dense
    w
    b
end
function Dense(i::Int,o::Int)     
    w = param(o,i)
    b = param0(o)
    return Dense(w,b)
end

function (d::Dense)(x)
    return d.w * mat(x) .+ d.b
end

In [8]:
struct Chain
    layers
    Chain(layers...) = new(layers)
end
(c::Chain)(x) = (for l in c.layers; x = l(x); end; x)
(c::Chain)(x,y) = nll(c(x),y, average = true)
(c::Chain)(d::Data) = mean(c(x,y) for (x,y) in d)

In [9]:
function rel_logits_2d(flat_q, rel, kernel_size)
    h2, d, Nh, b = size(flat_q)
    flat_q = reshape(flat_q, (1, h2, d, Nh, b))
    rel = flatten_rel(rel, kernel_size)
    rel_logits = rel .* flat_q
    rel_logits = reshape(sum(rel_logits, dims = 3),(h2, h2, Nh, b))
    return rel_logits
end

rel_logits_2d (generic function with 1 method)

In [10]:
function flatten_hw(input)
    h, w, d, Nh, b = size(input)
    new_size = (h*w, d, Nh, b)
    return reshape(input, new_size)
end

flatten_hw (generic function with 1 method)

In [11]:
function flatten_rel(rel, kernel_size)
    h2, w2, c, Nh, b = size(rel)
    temp_size = (kernel_size, kernel_size, kernel_size, kernel_size, c, Nh, b)
    rel = permutedims(reshape(rel, temp_size),[1,3,2,4,5,6,7])
    new_size = (h2, w2, c, Nh, b)
    rel = reshape(rel, new_size)
    return rel
end

flatten_rel (generic function with 1 method)

In [12]:
function split_heads_2d(inputs, Nh)
    h, w, d, b = size(inputs)
    ret_shape = (h, w, floor(Int,d/Nh), Nh, b)
    out = reshape(inputs, ret_shape)
    return out
end

split_heads_2d (generic function with 1 method)

In [13]:
function combine_heads_2d(inputs)
    h, w, dh, Nh, b  = size(inputs)
    ret_shape = (h,w, dh*Nh, b)
    return reshape(inputs, ret_shape)
end

combine_heads_2d (generic function with 1 method)

In [14]:
struct self_attention_2d
    conv_q
    conv_k
    conv_v
    conv_rel
    deconv_rel
    conv_attn
    kernel_size
    stride
    padding
    dk
    dv
    Nh
    dkh
    dvh
end

In [15]:
function self_attention_2d(input_dims, kernel_size, stride, padding, Nh, dk, dv)
    conv_q = Conv(1,1, input_dims, dk)
    conv_k = Conv(1,1, input_dims, dk)
    conv_v = Conv(1,1, input_dims, dv)
    conv_rel = Conv(5, 5, input_dims, dk; padding = 2)
    deconv_rel = Deconv(kernel_size, kernel_size, dk, dk; stride = kernel_size)
    conv_attn = Conv(1,1, dv, dv)
    
    dkh = floor(Int, dk/Nh)
    dvh = floor(Int, dv/Nh)
    
    stride = stride
    padding = padding
    
    return self_attention_2d(conv_q, conv_k, conv_v, conv_rel, deconv_rel,
                                conv_attn, kernel_size, stride, padding, dk, dv, Nh, dkh, dvh)
end

self_attention_2d

In [16]:
function odims(input, kernel_size, stride, padding, dv)
    inh,inw,inc,b = size(input)
    out_dims_h = Int(((inh-kernel_size) + 2*pad)/stride + 1)
    out_dims = (out_dims_h^2, dv, b)
    return out_dims, out_dims_h
end

odims (generic function with 1 method)

In [17]:
function (s::self_attention_2d)(x)
    
    out_dims, out_h = odims(x, s.kernel_size, s.stride, s.padding, s.dv)
    out = nothing
    imh, imw, imc, b = size(x)
    
    for i = 1:s.stride:imh-s.kernel_size+1
        for j = 1:s.stride:imw-s.kernel_size+1

            x_patch = x[i:i+s.kernel_size-1, j:j+s.kernel_size-1, :, :]
            
            _, _, _, b = size(x_patch)
            q = s.conv_q(x_patch)
            k = s.conv_k(x_patch)
            v = s.conv_v(x_patch)

            rel = s.conv_rel(x_patch)
            rel = s.deconv_rel(rel)

            q = q .* (s.dkh ^ -0.5)

            q = split_heads_2d(q,s.Nh)
            k = split_heads_2d(k,s.Nh)
            v = split_heads_2d(v,s.Nh)
            rel = split_heads_2d(rel,s.Nh)

            flat_q = flatten_hw(q)
            flat_k = flatten_hw(k)
            flat_v = flatten_hw(v)

            logits = bmm(flat_q, flat_k, transB = true)
            logits = permutedims(logits, [2,1,3,4])

            rel_logits = rel_logits_2d(flat_q, rel, s.kernel_size)

            logits += rel_logits

            weights = softmax(logits; dims = 1)

            attn_out = bmm(weights, flat_v, transA = true)

            attn_out = reshape(attn_out, (s.kernel_size, s.kernel_size, s.dvh, s.Nh, :))
            attn_out = combine_heads_2d(attn_out) 

            attn_out = s.conv_attn(attn_out)
            
            attn_out = pool(attn_out; window = s.kernel_size, mode = 2)
            
            attn_out = reshape(attn_out, (1, s.dv, :))
            
            if out == nothing
                out = attn_out
            else
                out = cat(out, attn_out; dims = 1)
            end
            
        end
    end
    
    out = permutedims(reshape(out, (out_h, out_h, s.dv, :)), [2, 1, 3, 4])
    
    return out
end

In [18]:
mutable struct self_Attn_Net
    activation
    conv_i
    self_attn
    avg_pool
    fc
end

In [19]:
function self_Attn_Net(num_classes = 10)
    
    activation = relu
    conv_i = Conv(6,6,1,64; stride = 2, padding = 2)
    self_attn = self_attention_2d(64, 4, 8, 128, 128)
    avg_pool = pool
    fc = Dense(128, num_classes)
    
    return self_Attn_Net(activation, conv_i, self_attn, avg_pool, fc)
    
end

self_Attn_Net

In [20]:
function (r::self_Attn_Net)(x)
    x = r.conv_i(x)
    x = r.activation.(x)
    x = r.self_attn(x)
    x = r.activation.(x)
    x = r.avg_pool(x; window = 6, mode = 2)
    x = mat(x)
    x = r.fc(x)
    return x
end

In [21]:
function (r::self_Attn_Net)(x,y)
    scores = r(x)
    loss = nll(scores, y)
    return loss
end

function (r::self_Attn_Net)(d::Data)
    mean_loss = mean(r(x,y) for (x,y) in d)
    return mean_loss
end


In [22]:
function train(file, dtrn, dtst, epochs; lr = 0.001)
    net = self_Attn_Net()
    avgloss = []
    sumloss = 0
    currloss = []
    len = length(dtrn)
    iteration = 0
    ind = []
    for e = 1:epochs
        for (i,v) in enumerate(adam(net, dtrn; lr = lr))
            iteration += 1
            push!(ind, iteration)
            push!(currloss,v)
            sumloss += v
            avg_temp = sumloss / iteration
            push!(avgloss, avg_temp)
            println("iteration: $i / $len   loss: $v")
        end
        acc = accuracy(net, dtst)
        println("epoch: $e    test_acc: $acc")
        Knet.gc()
    end
    Knet.save(file,"net", net)
    return ind, currloss, avgloss
end

train (generic function with 1 method)

In [23]:
ind, currloss, avgloss = train("mnist_stride.jld2", dtrn, dtst, 3)

iteration: 1 / 2400   loss: 2.2992837
iteration: 2 / 2400   loss: 2.304644
iteration: 3 / 2400   loss: 2.2968507
iteration: 4 / 2400   loss: 2.296804
iteration: 5 / 2400   loss: 2.2750115
iteration: 6 / 2400   loss: 2.3167093
iteration: 7 / 2400   loss: 2.2964826
iteration: 8 / 2400   loss: 2.30612
iteration: 9 / 2400   loss: 2.2734694
iteration: 10 / 2400   loss: 2.2977426
iteration: 11 / 2400   loss: 2.3079855
iteration: 12 / 2400   loss: 2.2847757
iteration: 13 / 2400   loss: 2.29953
iteration: 14 / 2400   loss: 2.3231761
iteration: 15 / 2400   loss: 2.3014867
iteration: 16 / 2400   loss: 2.290228
iteration: 17 / 2400   loss: 2.335736
iteration: 18 / 2400   loss: 2.2679193
iteration: 19 / 2400   loss: 2.2566788
iteration: 20 / 2400   loss: 2.2976248
iteration: 21 / 2400   loss: 2.3386312
iteration: 22 / 2400   loss: 2.3025866
iteration: 23 / 2400   loss: 2.2879744
iteration: 24 / 2400   loss: 2.279912
iteration: 25 / 2400   loss: 2.3329585
iteration: 26 / 2400   loss: 2.2911625
iter

(Any[1, 2, 3, 4, 5, 6, 7, 8, 9, 10  …  7191, 7192, 7193, 7194, 7195, 7196, 7197, 7198, 7199, 7200], Any[2.2992837f0, 2.304644f0, 2.2968507f0, 2.296804f0, 2.2750115f0, 2.3167093f0, 2.2964826f0, 2.30612f0, 2.2734694f0, 2.2977426f0  …  0.29266876f0, 0.17620596f0, 0.026909351f0, 0.021444358f0, 0.0047896574f0, 0.0058618356f0, 0.7544183f0, 0.33267215f0, 0.11194071f0, 0.17458138f0], Any[2.2992837f0, 2.3019638f0, 2.3002594f0, 2.2993956f0, 2.2945187f0, 2.298217f0, 2.2979693f0, 2.298988f0, 2.2961528f0, 2.2963119f0  …  0.3194391f0, 0.3194192f0, 0.31937853f0, 0.31933713f0, 0.3192934f0, 0.31924987f0, 0.31931034f0, 0.31931219f0, 0.3192834f0, 0.3192633f0])

In [None]:
plot([avgloss], ylim=(0.0 , 3.0 ),
     labels=[:trn_avgloss], xlabel = "Iterations", ylabel = "Loss")