# AffNIST Dataset

## Load and minibatch AffNIST data

In [None]:
using Knet   
using Compat,GZip
using Images
using MAT

In [None]:
# Array type used for the model weights and the minibatch inputs (Knet GPU array of Float32).
Atype = KnetArray{Float32}
# Plain CPU Float32 array type; only displayed later in this notebook.
Ctype = Array{Float32}
# NOTE(review): presumably selects/reports the active GPU through Knet — confirm against the Knet docs.
gpu()

In [None]:
"Where to download affNIST from"
affnisturl = "http://www.cs.toronto.edu/~tijmen/affNIST/32x/transformed/"

"Where to download affNIST to"
affnistdir = "affNIST/transformed/"

In [None]:
# There are 32 training, 32 test and 32 validation files. Each training file contains 50K
# images and each test file contains 10K images. Only the first training file and the first
# test file are read for now.
# The paths are built from `affnistdir` so the data location is configured in one place.
trndata = matread(joinpath(affnistdir, "training_batches", "1.mat"))
tstdata = matread(joinpath(affnistdir, "test_batches", "1.mat"))
trndata["affNISTdata"]

In [None]:
tstdata["affNISTdata"]

In [None]:
# Pull the raw image matrices out of the loaded .mat structures (the trailing `;` hides the output).
xtrnraw = trndata["affNISTdata"]["image"];
xtstraw = tstdata["affNISTdata"]["image"];
# Labels stored under the `label_int` key, for the training and the test file respectively.
ytrn, ytst = trndata["affNISTdata"]["label_int"], tstdata["affNISTdata"]["label_int"]

In [None]:
# Represent 0 as 10th class
# Every label equal to 0 is overwritten in place with 10, so labels can serve as 1-based
# row indices (conv1hot uses them as the row index of the one-hot matrix).
ytrn[ytrn.==0]=10;
ytst[ytst.==0]=10;

In [None]:
# Flatten the label arrays into 1-D vectors with one entry per image.
ytrn = reshape(ytrn, (length(ytrn)))
ytst = reshape(ytst, (length(ytst)));

In [None]:
map(summary,(xtrnraw,xtstraw))  # (50K + 10K) 40x40 images stored in columns of size 1600

In [None]:
# Divide the raw pixel values by 255, reshape each 1600-element column into a 40x40x1 image
# (giving a 4-D array of size 40 x 40 x 1 x N) and store the result as Float32.
# NOTE(review): assumes the raw values are in 0..255, so the result lies in [0,1] — confirm.
xtrn = convert(Array{Float32}, reshape(xtrnraw ./ 255, (40,40,1, div(length(xtrnraw), 1600))));
xtst = convert(Array{Float32}, reshape(xtstraw ./ 255, (40,40,1, div(length(xtstraw), 1600))));

# Show the resulting array types and sizes.
map(summary,(xtrn,xtst))

In [None]:
# Display helper: take image `i` (channel 1) out of the 4-D array `x`, swap its two spatial
# axes and wrap it as a grayscale image.
function mnistview(x, i)
    plane = x[:, :, 1, i]
    return colorview(Gray, permutedims(plane, (2, 1)))
end

In [None]:
hcat([mnistview(xtrn,i) for i=1:5]...)

In [None]:
ytrn[1:5]'

In [None]:
hcat([mnistview(xtst,i) for i=1:5]...)

In [None]:
ytst[1:5]'

In [None]:
# Global batch size; capsnet, mask and the reshapes in this notebook rely on this value.
Nb = 80  # batch size
# Inputs are converted to Atype; each element of dtst/dtrn is an (x, y) minibatch pair.
dtst = minibatch(xtst,ytst,Nb;xtype=Atype) # [ (x1,y1), (x2,y2), ... ] where xi,yi are minibatches of Nb
dtrn = minibatch(xtrn,ytrn,Nb;xtype=Atype) # [ (x1,y1), (x2,y2), ... ] where xi,yi are minibatches of Nb
# Number of minibatches in the training and the test set.
length(dtrn),length(dtst)

In [None]:
dtst

In [None]:
# This is the first minibatch
map(summary,first(dtst))  # (x,y) pair where x: 4-D Float32 array with X,Y,C,N  y: 1-D integer array

In [None]:
# Keep the first test minibatch in the globals x and y; later cells use it to smoke-test
# capsnet, lossCaps and accurate.
(x,y) = first(dtst)

In [None]:
summary(y)

In [None]:
Knet.gpuinfo();

## Define the Capsule Network

In [None]:
Atype, Ctype

In [None]:
# Capsule-network hyperparameters (globals read by wtsinit, capsnet, mask and conv1hot).
Nclass = 10  # number of classes
Vprimes = 32  # number of primary capsules stacked along the channel (z) axis, all looking at the same image segment
Nax1primes = 6  # number of primary capsules along axis-1 (y-axis) of an image: Julia is column-major!
Nax2primes = 6  # number of primary capsules along axis-2 (x-axis) of an image
Nsegm = Nax1primes*Nax2primes  # number of segments (spatial positions) per image
Nprimes = Nsegm*Vprimes  # total number of primary capsules per image
Dprime = 8  # dimension of a primary capsule
Dsecond = 16 # dimension of a secondary (higher layer) capsule
Nchannels = Vprimes*Dprime  # number of conv. layer channels needed to hold Vprimes capsules of dimension Dprime

In [None]:
"""
    winitdecoder(h...; x=Nclass*Dsecond, y=40*40)

Create the parameters of the reconstruction MLP. `h...` lists the hidden layer sizes, `x`
is the input size and `y` the output size. The result is a flat `Any[]` list
`[W1, b1, W2, b2, ...]` in which each `W` is an `xavier(out, in)` matrix and each `b` is a
zero column of size `(out, 1)`.
"""
function winitdecoder(h...; x=Nclass*Dsecond, y=40*40) # AffNIST images are 40x40
    sizes = [x, h..., y]   # call winitdecoder(h1,h2,...,hn) for an MLP with n hidden layers
    params = Any[]
    for (nin, nout) in zip(sizes[1:end-1], sizes[2:end])
        push!(params, xavier(nout, nin))
        push!(params, zeros(nout, 1))
    end
    return params
end

In [None]:
"""
    wtsinit()

Create all model parameters and convert each one to `Atype`. Layout of the returned list:
entries 1-4 are the weights and biases of the two 7x7 convolution layers, entry 5 is the
capsule transformation tensor of size `(Dprime, Nprimes, Nclass, Dsecond)`, and the
remaining entries are the decoder MLP parameters from `winitdecoder(512, 1024)`.
"""
function wtsinit()
    params = Any[]
    # first convolution: 1 input channel -> Nchannels
    push!(params, xavier(7, 7, 1, Nchannels))
    push!(params, zeros(1, 1, Nchannels, 1))
    # second convolution: Nchannels -> Nchannels
    push!(params, xavier(7, 7, Nchannels, Nchannels))
    push!(params, zeros(1, 1, Nchannels, 1))
    # capsule transformation tensor, small Gaussian initialisation
    push!(params, 0.1*randn(Dprime, Nprimes, Nclass, Dsecond))
    # decoder with hidden layers of 512 and 1024 units
    append!(params, winitdecoder(512, 1024))
    return map(Atype, params)
end

In [None]:
"""
    convLayer(w, x; strides=(2,2), paddings=(1,0))

Apply a stack of convolution + bias + ReLU layers to `x`. `w` holds alternating weights and
biases (`w[1], w[2]` for the first layer, `w[3], w[4]` for the next, ...). The first layer
uses `strides[1]`/`paddings[1]`; every later layer uses `strides[2]`/`paddings[2]`.
"""
function convLayer(w, x; strides=(2,2), paddings=(1,0))
    for i in 1:2:length(w)
        k = i == 1 ? 1 : 2   # which stride/padding entry this layer uses
        pre = conv4(w[i], x, stride=strides[k], padding=paddings[k]) .+ w[i+1]
        x = relu.(pre)
    end
    return x
end

In [None]:
"""
    softMax(X; axis=2)

Softmax of `X` along dimension `axis`. The maximum along `axis` is subtracted first for
numerical stability; the exponentials are computed once and reused for the normalisation.
"""
function softMax(X; axis=2)
    shifted = X .- maximum(X, axis)
    expX = exp.(shifted)   # computed once instead of twice
    return expX ./ sum(expX, axis)
end

In [None]:
"""
    safeNorm(s; axis=4, eps=1e-7)

Euclidean norm of `s` along dimension `axis`. `eps` is added under the square root so the
result (and its gradient) stays finite for all-zero vectors.
"""
function safeNorm(s; axis=4, eps=1e-7)
    return sqrt.(sum(abs2.(s), axis) + eps)
end

In [None]:
"""
    squash(s; axis=4, eps=1e-7)

Capsule squashing non-linearity along dimension `axis`: every vector keeps its direction
and its length `n` is mapped to `n^2 / (n^2 + 1)`. `eps` is added under the square root
when computing the length.
"""
function squash(s; axis=4, eps=1e-7)
    len2 = sum(abs2.(s), axis)     # squared length
    len = sqrt.(len2 + eps)        # length, guarded by eps
    unit = s ./ len                # direction
    gain = len2 ./ (len2 .+ 1)     # squashed length
    return gain .* unit
end

In [None]:
"""
    maxidx(matrix)

For every column of `matrix` return the row index of its maximum as an `Int8` vector. When
a column contains its maximum more than once, the last such row is reported; an entry
stays 0 if no element compares equal to the column maximum.
"""
function maxidx(matrix)  # find index of maximums along axis-1 (columns)
    nrows, ncols = size(matrix, 1), size(matrix, 2)
    colmax = maximum(matrix, 1)
    winners = zeros(Int8, ncols)
    for col in 1:ncols, row in 1:nrows
        if colmax[col] == matrix[row, col]
            winners[col] = row   # later rows overwrite earlier ones
        end
    end
    return winners
end

In [None]:
# One-hot encode the label vector `y` as an Nclass x length(y) matrix: column j has a one in
# row y[j]. The result is converted to `Atype`, the array type used everywhere else in this
# notebook, instead of a hard-coded KnetArray{Float32}.
conv1hot(y) = convert(Atype, sparse(convert(Vector{Int},y),1:length(y),one(eltype(y)),Nclass,length(y)))

In [None]:
"""
    decode(w, vtodeco)

Reconstruction decoder: an MLP whose parameters `w` alternate weight matrix and bias. All
layers except the last use ReLU; the last layer uses a sigmoid.
"""
function decode(w,vtodeco)  # decoder for reconstruction
    h = vtodeco
    lasthidden = length(w) - 2
    for i in 1:2:lasthidden
        h = relu.(w[i]*h .+ w[i+1])
    end
    return sigm.(w[end-1]*h .+ w[end])
end

In [None]:
"""
    mask(v, y1ht)

Zero out every capsule except the one selected by the one-hot matrix `y1ht` and flatten the
result for the decoder.

`v` has size `(nclass, nb, dsecond)` and `y1ht` has size `(nclass, nb)`. The result has
size `(dsecond*nclass, nb)`. All sizes are taken from `v`, so the function does not depend
on the global batch size.
"""
function mask(v, y1ht)  # filter for reconstruction
    nclass, nb, dsecond = size(v, 1), size(v, 2), size(v, 3)
    vmasked = permutedims(y1ht.*v , (1,3,2))   # -> (nclass, dsecond, nb)
    return reshape(vmasked, (dsecond*nclass, nb))
end

In [None]:
"""
    capsnet(w, x)

Forward pass of the capsule network.

`w[1:4]` are the weights and biases of the two convolution layers and `w[5]` is the capsule
transformation tensor; the other entries of `w` are not read here. The batch size is taken
from `size(x, 4)` rather than from the global `Nb`.

Returns `(yprob, v)`: `yprob` has size `(Nclass, nb)` and holds the length of every output
capsule; `v` has size `(Nclass, nb, Dsecond)` and holds the output capsule vectors.
"""
function capsnet(w, x)
    nb = size(x, 4)  # batch size of this input
    con = convLayer(w[1:4],x)
    pri = reshape(con,(Nax1primes,Nax2primes,Dprime,Vprimes,nb))  # e.g. (6,6,8,32,80)
    pri = permutedims(pri,(1,2,4,3,5))

    pri = reshape(pri,(Nprimes,Dprime,nb))  # e.g. (1152,8,80)
    pri = permutedims(pri,(2,1,3))
    pri = squash(pri, axis=1)  # along the contents of primary capsules

    pri = reshape(pri,(Dprime,Nprimes,1,1,nb))  # e.g. (8,1152,1,1,80)

    UHat = convert(Atype, ones(Dprime,Nprimes,Nclass,Dsecond,nb))  # e.g. ones(8,1152,10,16,80)
    W = w[5].*UHat  # tile W along the batch dimension by multiplying with ones

    UHat = pri.*W  # prediction vectors
    UHat = sum(UHat,1)  # sum over Dprime: affine transformations without matmul()
    UHat = permutedims(reshape(UHat,(Nprimes,Nclass,Dsecond,nb)), (1,2,4,3))  # -> (Nprimes,Nclass,nb,Dsecond)

    B = convert(Atype, zeros(Nprimes,Nclass,nb))  # routing logits, e.g. (1152,10,80)

    # Initial pass with all-zero logits.
    C = softMax(B, axis=2) # C is normalized along 2nd dim (classes)
    S = C.*UHat
    s = sum(S,1)
    v = squash(s)

    # Routing: add the agreement between outputs and predictions to the logits, then recompute.
    maxiter = 1
    for r=1:maxiter
        A = v.* UHat
        Agreement = sum(A,4)
        Agreement = reshape(Agreement, (Nprimes,Nclass,nb))  # e.g. (1152,10,80)
        B = B .+ Agreement

        C = softMax(B, axis=2)
        S = C.*UHat
        s = sum(S,1)
        v = squash(s)
    end

    yprob = safeNorm(v)
    yprob = reshape(yprob, (Nclass,nb))  # e.g. (10,80)
    v = reshape(v, (Nclass,nb,Dsecond))  # e.g. (10,80,16)

    return yprob, v
end

In [None]:
# Loss function parameters (globals read by lossCaps)
m_plus = 0.9     # margin for the true class: penalised through relu(m_plus - yp)
m_minus = 0.1    # margin for the other classes: penalised through relu(yp - m_minus)
lambda = 0.5     # weight of the other-classes term of the margin loss
rloss = 0.0005;  # weight of the reconstruction loss in the total loss

In [None]:
"""
    lossCaps(w, x, y, predict; training=true)

Loss of one minibatch: margin loss plus `rloss` times the reconstruction loss, both summed
over the batch.

`predict(w, x)` must return the class scores and the capsule vectors. With `training=true`
the decoder is fed the capsule of the true label; otherwise the capsule of the predicted
class (`maxidx` of the scores) is used. `w[6:11]` are the decoder parameters.
"""
function lossCaps(w, x, y, predict; training = true)
    yp, v = predict(w, x)
    nb = size(x,4)
    # Flatten the labels using the actual batch size (previously a hard-coded 80).
    y1hot = conv1hot(reshape(y, nb))
    lossmargin = sum((abs2.(relu.(m_plus.-yp)).*y1hot)+lambda.*(abs2.(relu.(yp.-m_minus))).*(1-y1hot))
    xmat = mat(x)
    if training
        vmasked = mask(v, y1hot)
    else
        yp1hot = conv1hot(maxidx(yp))
        vmasked = mask(v, yp1hot)
    end
    xr = decode(w[6:11], vmasked)
    lossreconstruction = sum(abs2.(xmat.-xr))
    return lossmargin + rloss*lossreconstruction
end

In [None]:
"""
    accurate(ygold, yhat)

Fraction of positions at which `ygold` and `yhat` agree. `ygold` may have any shape; it is
flattened first and compared element by element with `yhat`.
"""
function accurate(ygold, yhat)
    n = length(ygold)
    gold = reshape(ygold, (n))
    hits = 0.0
    for i in 1:n
        if gold[i] == yhat[i]
            hits += 1.0
        end
    end
    return hits / n
end

In [None]:
"""
    accurate(w, x, y, predict)

Accuracy of the model on one minibatch: the predicted class is the `maxidx` of the scores
returned by `predict(w, x)` and is compared with the labels `y`.
"""
function accurate(w,x,y,predict)
    scores, _ = predict(w, x)
    return accurate(y, maxidx(scores))
end

In [None]:
# Dataset-level metrics for Capsnet: each one averages the per-minibatch value over `data`.

# Error rate: one minus the mean accuracy.
function pererror(w, data, predict)
    return 1 - accurate(w, data, predict)
end

# Mean minibatch accuracy.
function accurate(w, data, predict)
    return mean(accurate(w, x, y, predict) for (x, y) in data)
end

# Mean minibatch loss; keyword arguments are forwarded to the minibatch method.
function lossCaps(w, data, predict; o...)
    return mean(lossCaps(w, x, y, predict; o...) for (x, y) in data)
end

# Gradient of lossCaps with respect to its first argument.
lossCapsgrad = grad(lossCaps)

In [None]:
"""
    trainCaps(w, data, predict; epochs=3, lr=0.15, o...)

Train the parameters `w` in place with Adam and return the list of snapshots
`[w before training, w after epoch 1, ..., w after the last epoch]`.

`lr` is accepted but not used: the optimizers are created with `Adam()` and its default
settings. Any other keyword arguments are forwarded to `lossCapsgrad`.
"""
function trainCaps(w,data,predict; epochs=3,lr=0.15,o...)
    snapshots = Any[deepcopy(w)]
    optimizers = map(p -> Adam(), w)   # one optimizer state per parameter array
    for epoch in 1:epochs
        for (x, y) in data
            grads = lossCapsgrad(w, x, y, predict; o...)
            update!(w, grads, optimizers)   # in-place Adam step on every parameter
        end
        push!(snapshots, deepcopy(w))
    end
    return snapshots
end

In [None]:
# Freshly initialised parameters; trainCaps updates this list in place.
wtscap = wtsinit()

In [None]:
# Smoke test: forward pass of the untrained network on the first test minibatch.
yprb1, v1 = capsnet(wtscap,x)

In [None]:
conv1hot(maxidx(yprb1))

In [None]:
lossCaps(wtscap,x,y,capsnet)

In [None]:
accurate(wtscap,x,y,capsnet)

In [None]:
# wtscap = weights = trnloss = tstloss = trnerr = tsterr = nothing; knetgc()

In [None]:
# Train for 5 epochs; `weights` holds the initial parameters plus one snapshot per epoch.
# Note: trainCaps builds its optimizers with Adam() and does not read `lr`, so lr=0.15 has no effect.
@time weights = trainCaps(wtscap,dtrn,capsnet;epochs=5,lr=0.15)

In [None]:
# Training error after each epoch (the initial, untrained snapshot weights[1] is skipped).
@time trnerr = [ pererror(w,dtrn,capsnet) for w in weights[2:end] ]

In [None]:
# Test error after each epoch (the initial, untrained snapshot weights[1] is skipped).
@time tsterr = [ pererror(w,dtst,capsnet) for w in weights[2:end] ]

In [None]:
# Mean minibatch loss for every snapshot, including the untrained one, so these vectors
# have one more entry than trnerr/tsterr.
@time trnloss = [ lossCaps(w,dtrn,capsnet) for w in weights ]
@time tstloss = [ lossCaps(w,dtst,capsnet) for w in weights ]

In [None]:
using Plots
plotly()

In [None]:
# Plot training and test error per epoch as two curves; the y-axis is limited to 0..0.1.
Plots.plot([trnerr tsterr],ylim=(0,0.1),linewidth=3,labels=[:Train_Error :Test_Error],xlabel="Epochs", ylabel="AFFNIST Errors, Loss w/Reconstruction") 