# TIMIT Dataset

## Load and minibatch TIMIT data

In [None]:
using Knet   
using Compat,GZip
using Images
using HDF5, JLD

In [None]:
# Array types used by the rest of the notebook: Atype is the Knet array type that
# data and weights are converted to, Ctype is the plain Array counterpart.
Atype = KnetArray{Float32}
Ctype = Array{Float32}
# The cell displays the value returned by gpu(); winit below treats a
# non-negative value as "a GPU is available".
gpu()


In [None]:
# HDF5 file that the following cells read the labels and inputs from.
fileToRead = "timitdata.hdf5"

In [None]:
# Training labels (dataset "labelsTrn").
ytrn = h5read(fileToRead,"labelsTrn")
summary(ytrn)

In [None]:
# Test labels (dataset "labelsTst").
ytst = h5read(fileToRead,"labelsTst")
summary(ytst)

In [None]:
# Training inputs (dataset "imagesTrn").
xtrn = h5read(fileToRead,"imagesTrn")
summary(xtrn)

In [None]:
# Test inputs (dataset "imagesTst").
xtst = h5read(fileToRead,"imagesTst")
summary(xtst)

In [None]:
# Reshape the inputs to 40x14x1xN, the 4-D layout later cells index as x[:,:,1,i].
# N is derived from the array length instead of being hard-coded, so the cell keeps
# working when the HDF5 file holds a different number of samples; reshape still
# throws if the length is not a multiple of 40*14.
xtrn = reshape(xtrn, (40,14,1,div(length(xtrn),40*14)))
xtst = reshape(xtst, (40,14,1,div(length(xtst),40*14)))
map(summary,(xtrn,xtst))

In [None]:
# Mean over every element of the training inputs.
muX = mean(xtrn)

In [None]:
# Center both splits with the training mean; the test split is shifted by the
# same constant rather than by its own mean.
xtrn = xtrn.-muX
xtst = xtst.-muX;

In [None]:
# Means after centering, displayed as a check.
mean(xtrn), mean(xtst)

In [None]:
# Show sample `i` of the 4-D array `x` as a grayscale image, with the two
# leading axes swapped for display.
function quickview(x,i)
    sample = x[:,:,1,i]
    return colorview(Gray, permutedims(sample,(2,1)))
end

In [None]:
# First ten training samples side by side.
hcat([quickview(xtrn,i) for i=1:10]...)

In [None]:
# Labels of those ten training samples.
ytrn[1:10]'

In [None]:
# First ten test samples side by side.
hcat([quickview(xtst,i) for i=1:10]...)

In [None]:
# Labels of those ten test samples.
ytst[1:10]'

In [None]:
# Convert the inputs to Atype before minibatching.
xtrn = convert(Atype, xtrn);
xtst = convert(Atype, xtst);
map(summary,(xtrn,xtst))

In [None]:
# Minibatch size. capsnet and mask further down read this global as well.
Nb = 40
dtst = minibatch(xtst,ytst,Nb;xtype=Atype) # [ (x1,y1), (x2,y2), ... ] where xi,yi are minibatches of Nb
dtrn = minibatch(xtrn,ytrn,Nb;xtype=Atype) # [ (x1,y1), (x2,y2), ... ] where xi,yi are minibatches of Nb
length(dtrn),length(dtst)

In [None]:
dtst

In [None]:
# One (inputs, labels) minibatch, reused by the smoke tests below.
(x,y) = first(dtst)

In [None]:
summary(x)

In [None]:
summary(y)

In [None]:
knetgc()

## Train TIMIT using MLP with dropout

In [None]:
# Loss helpers shared by the MLP and CNN experiments.

# Error rate over a dataset: one minus `accuracy`.
function zeroone(w,data,predict)
    return 1 - accuracy(w,data,predict)
end

# Mean of the per-minibatch losses over every (x,y) pair in `data`.
function loss(w,data,predict)
    return mean(loss(w,x,y,predict) for (x,y) in data)
end

# `nll` of one minibatch; keyword options are forwarded to `predict`.
function loss(w,x,y,predict; o...)
    scores = predict(w,x;o...)
    return nll(scores,y)
end

# Gradient function built from `loss` by `grad`.
lossgrad = grad(loss)

In [None]:
# Multi-layer perceptron with dropout.
#   w      weights as [W1,b1,W2,b2,...,Wn,bn]
#   x      input batch, flattened with `mat` before the first layer
#   pdrop  (input dropout, hidden dropout) probabilities
# Every layer but the last applies relu followed by hidden dropout; the last
# layer returns its affine output unchanged.
function mlpdrop(w,x; pdrop=(0,0))
    h = dropout(mat(x), pdrop[1])
    for i in 1:2:length(w)-2
        h = dropout(relu.(w[i]*h .+ w[i+1]), pdrop[2])
    end
    return w[end-1]*h .+ w[end]
end

In [None]:
# Random initial parameters for an MLP; call winit(h1,h2,...,hn) for n hidden layers.
#   std    scale applied to the randn-drawn weight matrices
#   x, y   input and output sizes
#   atype  array type every parameter is converted to
# Returns [W1,b1,...] where Wk has size (out,in) and bk is a zero column of size (out,1).
function winit(h...; std=0.01, x=40*14, y=45, atype=gpu()>=0 ? KnetArray{Float32} : Array{Float32})
    sizes = [x, h..., y]
    params = Any[]
    for (nin, nout) in zip(sizes[1:end-1], sizes[2:end])
        push!(params, std*randn(nout,nin))
        push!(params, zeros(nout,1))
    end
    return map(atype, params)
end

In [None]:
# MLP with three hidden layers of 1024, 256 and 64 units (default input/output sizes).
wts=winit(1024,256,64) # gives weights and biases for an MLP

In [None]:
# x,y are the minibatch taken from dtst above; no pdrop is passed, so mlpdrop
# uses its default pdrop=(0,0).
loss(wts,x,y,mlpdrop)  # Average loss for a single (x,y) minibatch

In [None]:
loss(wts,dtst,mlpdrop)  # Average loss for the whole test set

In [None]:
# Train `w` in place and return a list with a copy of `w` before training and after
# every epoch (epochs+1 entries).
# One Adam() optimizer with default settings is created per element of `w`.
# `lr` is accepted but not used by the current body; it only matters for the
# commented-out Sgd alternative. Remaining keywords `o...` are forwarded to `lossgrad`.
function train(w,data,predict; epochs=10,lr=0.15,o...)
    # optimizers = map(p->Sgd(lr=lr), w)   # plain SGD alternative
    optimizers = map(p->Adam(), w)
    snapshots = Any[deepcopy(w)]
    for epoch in 1:epochs
        for (x,y) in data
            gradients = lossgrad(w,x,y,predict;o...)
            update!(w,gradients,optimizers)
        end
        push!(snapshots,deepcopy(w))
    end
    return snapshots
end

In [None]:
# srand(1)
# Train the MLP for 15 epochs with pdrop=(0.2,0). train() updates wts in place and
# trn1 holds one weight snapshot per epoch plus the initial one. The lr keyword is
# passed through but train() does not use it.
@time trn1=train(wts,dtrn,mlpdrop;epochs=15,lr=0.15,pdrop=(0.2,0));


In [None]:
# Loss of every snapshot on both splits; mlpdrop runs with its default pdrop here.
@time trnloss1 = [ loss(w,dtrn,mlpdrop) for w in trn1 ]
@time tstloss1 = [ loss(w,dtst,mlpdrop) for w in trn1 ]

In [None]:
# Training error rate of every snapshot.
@time trnerr1 = [ zeroone(w,dtrn,mlpdrop) for w in trn1 ]'

In [None]:
# Test error rate of every snapshot.
@time tsterr1 = [ zeroone(w,dtst,mlpdrop) for w in trn1 ]'

In [None]:
# Lowest error reached on each split (the two minima may come from different snapshots).
minimum(trnerr1),minimum(tsterr1)

In [None]:
# Drop the references to the MLP results, then call knetgc().
wts = trn1 = trnloss1 = tstloss1 = trnerr1 = tsterr1 = nothing; knetgc()

## Train TIMIT using a Baseline CNN Model

In [None]:
# Forward pass of a CNN whose layer kinds are inferred from the weight ranks.
#   w           [W1,b1,W2,b2,...]; a 4-D Wi is a convolution filter, a 2-D Wi a dense matrix
#   activation  (conv activation, dense activation)
#   pdrop       dropout probabilities: pdrop[1] input, pdrop[2] conv, pdrop[3] fc
# Convolution layers: dropout -> conv4 (padding=1) + bias -> activation -> pool.
# Dense layers: dropout is applied only in front of the final layer, and the
# activation only after the non-final ones, so the final layer returns raw scores.
function convnet(w,x; activation=(relu,relu), pdrop=(0,0,0))
    final = length(w)-1    # index of the last layer's weight array
    for i in 1:2:length(w)
        rank = ndims(w[i])
        if rank == 4        # convolutional layer
            slot = (i == 1) ? 1 : 2
            x = dropout(x, pdrop[slot])
            x = conv4(w[i], x, padding=1) .+ w[i+1]
            x = pool(activation[1].(x))
        elseif rank == 2    # fully connected layer
            if i == final
                slot = (i == 1) ? 1 : 3
                x = dropout(x, pdrop[slot])
            end
            x = w[i]*mat(x) .+ w[i+1]
            if i < final
                x = activation[2].(x)
            end
        else
            error("Unknown layer type: $(size(w[i]))")
        end
    end
    return x
end;

In [None]:
# Weight initialization for a stack of convolutional and fully connected layers.
# Call as cinit(x,h1,h2,...,hn,y): the first argument describes the input, every
# later argument describes one layer. An Integer is the width of a fully connected
# layer; a triple (w1,w2,cy) is a convolution filter of size w1 x w2 with cy outputs.
# Returns [w0,b0,w1,b1,...,wn,bn] (weights and zero biases), each converted with Atype.
function cinit(h...)
    w = Any[]
    insize = h[1]
    for i in 2:length(h)
        spec = h[i]
        if isa(spec,Tuple)
            (inrows,incols,inchan) = insize
            (frows,fcols,outchan) = spec
            push!(w, xavier(frows,fcols,inchan,outchan))
            push!(w, zeros(1,1,outchan,1))
            # Spatial size seen by the next layer: filter applied with one cell of
            # padding per side (the +2), then halved.
            insize = (div(inrows-frows+1+2,2), div(incols-fcols+1+2,2), outchan)
        elseif isa(spec,Integer)
            push!(w, xavier(spec,prod(insize)))
            push!(w, zeros(spec,1))
            insize = spec
        else
            error("Unknown layer type: $(h[i])")
        end
    end
    map(Atype, w)
end;

In [None]:
# Baseline CNN: 40x14x1 input, three 5x3 convolution layers (256, 256, 128 filters),
# then fully connected layers of 328 and 192 units and a 45-way output.
cnnbase=cinit((40,14,1), (5,3,256), (5,3,256), (5,3,128), 328, 192, 45)

In [None]:
# Smoke test on a single minibatch before training.
(x,y) = first(dtst)
loss(cnnbase,x,y,convnet)

In [None]:
# NOTE(review): srand(1) runs after cnnbase was created in an earlier cell, so it
# does not make the weight initialization reproducible -- confirm this is intended.
srand(1)
# train() updates cnnbase in place; the lr keyword is passed through but unused there.
@time weights=train(cnnbase,dtrn,convnet;epochs=10,lr=0.15,pdrop=(0,0,0.30))
#@time trnloss = [ loss(w,dtrn,convnet) for w in weights ]
#@time tstloss = [ loss(w,dtst,convnet) for w in weights ]
@time trnerr = [ zeroone(w,dtrn,convnet) for w in weights ]

In [None]:
@time tsterr = [ zeroone(w,dtst,convnet) for w in weights ]

In [None]:
# Lowest error reached on each split (possibly at different epochs).
minimum(trnerr),minimum(tsterr)

In [None]:
# Drop the references to the CNN results, then call knetgc().
cnnbase = weights = trnloss = tstloss = trnerr = tsterr = nothing; knetgc()
# Knet.gpuinfo() # Knet.meminfo() # Knet.memdbg()

## Define the Capsule Network

In [None]:
Atype, Ctype

In [None]:
# Capsule-network hyper-parameters. They are plain globals read by wtsinit,
# winitdecoder, conv1hot and capsnet below.
Nclass = 45  # number of classes
Vprimes = 16  # number of primary capsules vertically stacked (along z-axis), focusing on the same segment of images
Nax1primes = 12  # number of primary capsules along axis-1 (y-axis) of an image: Julia is column-major!
Nax2primes = 3  # number of primary capsules along axis-2 (x-axis) of an image
Nsegm = Nax1primes*Nax2primes  # number of segments per image: one primary capsule for each segment of each image
Nprimes = Nsegm*Vprimes  # total number of primary capsules
Dprime = 8  # dimension of a primary capsule
Dsecond = 16 # dimension of a secondary (higher layer) capsule
Nchannels = Vprimes*Dprime  # number of output channels of the convolution filters built in wtsinit

In [None]:
# Weights for the reconstruction decoder MLP; call winitdecoder(h1,h2,...,hn) for n
# hidden layers. Returns [W1,b1,...] with xavier weights of size (out,in) and zero
# biases of size (out,1). The arrays are returned as created: the Atype conversion
# is left to the caller (see the commented line at the end).
# NOTE(review): the default output size y=28*14 differs from the 40*14 input size
# used elsewhere in this notebook -- confirm before enabling reconstruction.
function winitdecoder(h...; x=Nclass*Dsecond, y=28*14)
    dims = [x, h..., y]
    w = Any[]
    for i in 2:length(dims)
        push!(w, xavier(dims[i],dims[i-1]))
        push!(w, zeros(dims[i],1))
    end
    return w
    #map(Atype, w)
end

In [None]:
# Initial capsule-network parameters, each converted with Atype:
#   [1],[2]  first 9x5 convolution (1 -> Nchannels) and its zero bias
#   [3],[4]  second 9x5 convolution (Nchannels -> Nchannels) and its zero bias
#   [5]      transformation tensor of size (Dprime,Nprimes,Nclass,Dsecond)
# Decoder weights for reconstruction are not included (see the commented line).
function wtsinit()
    wts = Any[]
    push!(wts, xavier(9,5,1,Nchannels))
    push!(wts, zeros(1,1,Nchannels,1))
    push!(wts, xavier(9,5,Nchannels,Nchannels))
    push!(wts, zeros(1,1,Nchannels,1))
    # alternative: 0.1*randn(Dprime,Nprimes,Nclass,Dsecond)
    push!(wts, xavier(Dprime,Nprimes,Nclass,Dsecond))
    # append!(wts, winitdecoder(512,1024))  # for reconstruction
    return map(Atype, wts)
end

In [None]:
# Apply the convolution layers stored in `w` as [W1,b1,W2,b2,...] to `x`.
# Each layer is conv4 + bias followed by relu, with no dropout and no explicit
# padding argument. The first layer uses strides[1], every later layer strides[2].
function convLayer(w, x, strides=(1,2))
    for i in 1:2:length(w)
        step = strides[(i == 1) ? 1 : 2]
        x = relu.(conv4(w[i], x, stride=step) .+ w[i+1])
    end
    return x
end

In [None]:
# Softmax of `X` along dimension `axis`. The maximum along that dimension is
# subtracted first, so the largest exponent is exp(0).
function softMax(X; axis=2)
    shifted = X .- maximum(X, axis)
    return exp.(shifted) ./ sum(exp.(shifted), axis)
end

In [None]:
# Euclidean norm of `s` along dimension `axis`, computed as sqrt(sum of squares + eps)
# so the argument of sqrt is never exactly zero. The reduced dimension is kept with size 1.
function safeNorm(s; axis=4, eps=1e-7)
    sqnorm = sum(abs2.(s),axis)
    return sqrt.(sqnorm+eps)
end

In [None]:
# Squash `s` along dimension `axis`: every vector keeps its direction and its length
# becomes n2/(n2+1), where n2 is its squared norm. `eps` is added under the square
# root that normalizes the direction.
function squash(s; axis=4, eps=1e-7)
    sq = sum(abs2.(s),axis)
    len = sqrt.(sq+eps)
    direction = s ./ len
    gain = sq./(sq.+1)
    return gain.*direction
end

In [None]:
# Return, for every column of `matrix`, the row index of that column's maximum.
# When the maximum occurs in several rows the last of them is reported.
# A column whose maximum compares unequal to all of its entries (for example a
# column containing NaN) yields 0.
# Indices are stored as Int; the previous Int8 storage failed as soon as a row
# index exceeded 127. Only scalar indexing and `max` are used.
# Throws ArgumentError when the matrix has columns but no rows.
function maxidx(matrix)  # find index of maximums along axis-1 (columns)
    nrows = size(matrix,1)
    ncols = size(matrix,2)
    idxes = zeros(Int,ncols)
    if nrows == 0 && ncols > 0
        throw(ArgumentError("maxidx needs at least one row"))
    end
    for j in 1:ncols
        # Column maximum; `max` propagates NaN.
        top = matrix[1,j]
        for i in 2:nrows
            top = max(top, matrix[i,j])
        end
        # Keep the last row that holds the maximum.
        for i in 1:nrows
            if matrix[i,j] == top
                idxes[j] = i
            end
        end
    end
    return idxes
end

In [None]:
# One (inputs, labels) minibatch from the test set, used by the capsule-network
# smoke tests below.
(x,y) = first(dtst)

In [None]:
map(summary,first(dtst))

In [None]:
# One-hot encode the label vector `y` as an Nclass x length(y) array of type Ctype:
# column j holds a one in row y[j] and zeros elsewhere.
function conv1hot(y)
    n = length(y)
    rows = convert(Vector{Int},y)
    onehot = sparse(rows,1:n,one(eltype(y)),Nclass,n)
    return convert(Ctype, onehot)
end

In [None]:
# Decoder MLP: relu after every layer but the last, sigm on the output layer.
# `w` is [W1,b1,...,Wn,bn]; `vtodeco` is used as the matrix input of the first layer.
function decode(w,vtodeco)
    h = vtodeco
    for i in 1:2:length(w)-2
        h = relu.(w[i]*h .+ w[i+1])
    end
    return sigm.(w[end-1]*h .+ w[end])
end

In [None]:
# Zero every capsule of `v` that `y1ht` does not select, then flatten per sample.
#   v     3-D array, broadcast against y1ht over its first two dimensions
#   y1ht  mask multiplied element-wise into every slice v[:,:,k]
# Dimensions 2 and 3 of the product are swapped and the first two dimensions of the
# result are merged, giving a matrix with one column per entry of v's second dimension.
# The output size is taken from the arrays themselves rather than from the globals
# Dsecond, Nclass and Nb, so any batch size works.
function mask(v, y1ht)
    vmasked = permutedims(y1ht.*v , (1,3,2))
    nrows = size(vmasked,1)*size(vmasked,2)
    vtodecode = reshape(vmasked,(nrows,size(vmasked,3)))
    return vtodecode
end

In [None]:
# Capsule-network forward pass.
#   w  parameters as built by wtsinit: w[1:4] convolution weights and biases,
#      w[5] the (Dprime,Nprimes,Nclass,Dsecond) transformation tensor
#   x  input batch; its 4th dimension is the batch size
# Returns (yprob, v): yprob is (Nclass,nb) and holds the length of every class
# capsule, v is the capsule output reshaped to (Nclass,nb,Dsecond).
# The batch size is read from `x` instead of the global Nb, so batches of any size
# are accepted.
function capsnet(w, x)
    nb = size(x,4)

    # Primary capsules: convolution output regrouped into Nprimes vectors of length
    # Dprime per sample, then squashed.
    con = convLayer(w[1:4],x)
    pri = reshape(con,(Nax1primes,Nax2primes,Dprime,Vprimes,nb))
    pri = permutedims(pri,(1,2,4,3,5))
    pri = reshape(pri,(Nprimes,Dprime,nb))
    pri = permutedims(pri,(2,1,3))
    pri = squash(pri, axis=1)  # along the contents of primary capsules
    pri = reshape(pri,(Dprime,Nprimes,1,1,nb))

    # Prediction vectors: w[5] is tiled over the batch by multiplying with ones,
    # multiplied element-wise with the primary capsules and summed over Dprime.
    UHat = convert(Atype, ones(Dprime,Nprimes,Nclass,Dsecond,nb))
    W = w[5].*UHat  # achieved tiling to higher dimensions!

    UHat = pri.*W
    UHat = sum(UHat,1)  # achieved affine transformations without matmul!
    UHat = permutedims(reshape(UHat,(Nprimes,Nclass,Dsecond,nb)), (1,2,4,3))

    # Routing logits start at zero.
    B = convert(Atype, zeros(Nprimes,Nclass,nb))

    C = softMax(B, axis=2) # C is normalized along 2nd dim (classes)
    S = C.*UHat
    s = sum(S,1)
    v = squash(s)

    # Routing refinement: add the agreement between v and the prediction vectors
    # to the logits, then recompute the coupling and the output capsules.
    maxiter = 1
    for r=1:maxiter
        A = v.* UHat
        Agreement = sum(A,4)
        Agreement = reshape(Agreement, (Nprimes,Nclass,nb))
        B = B .+ Agreement

        C = softMax(B, axis=2)
        S = C.*UHat
        s = sum(S,1)
        v = squash(s)
    end

    yprob = safeNorm(v)
    yprob = reshape(yprob, (Nclass,nb))
    v = reshape(v, (Nclass,nb,Dsecond));

    return yprob, v

end

In [None]:
# Constants read by lossCaps below.
m_plus = 0.9    # margin in the term applied to the labelled class
m_minus = 0.1   # margin in the term applied to the other classes
lambda = 0.5    # weight of the other-class term
rloss = 0.0005; # weight of the reconstruction loss (only referenced in commented-out code)

In [None]:
# Margin loss of one minibatch, summed (not averaged) over classes and samples.
#   predict   called as predict(w,x); must return (class scores, capsule outputs)
#   training  only referenced by the commented-out reconstruction code below
# The labelled class contributes relu(m_plus - score)^2; every other class
# contributes lambda * relu(score - m_minus)^2.
# Changes from the previous version: the unused local `nb` is gone and the
# scalar-minus-array `1-y1hot` is written as the broadcast `1 .- y1hot`.
function lossCaps(w, x, y, predict; training = true)
    yp, v = predict(w, x)
    yp = convert(Ctype,yp)
    y1hot = conv1hot(y)
    present = abs2.(relu.(m_plus.-yp)).*y1hot
    absent = lambda.*(abs2.(relu.(yp.-m_minus))).*(1 .- y1hot)
    lossmargin = sum(present+absent)
    # xmat = mat(x)
    # if training
    #     vmasked = mask(v, y1hot)
    # else
    #     yp1hot = conv1hot(maxidx(yp))
    #     vmasked = mask(v, yp1hot)
    # end
    # xr = decode(w[6:11], vmasked)
    # lossreconstruction = sum(abs2.(xmat.-xr)) 
    return lossmargin # + rloss*lossreconstruction
end

In [None]:
# Fraction of positions i in 1:length(ygold) where ygold[i] == yhat[i].
# The result is NaN when `ygold` is empty.
function accurate(ygold, yhat)
    n = length(ygold)
    hits = count(i -> ygold[i] == yhat[i], 1:n)
    return hits / n
end

In [None]:
# Accuracy of `predict` on one minibatch: the class scores are converted to Ctype,
# the highest-scoring class of each column is found with maxidx, and the result
# is compared with the labels `y`.
function accurate(w,x,y,predict)
    scores, capsules = predict(w,x)
    predicted = maxidx(convert(Ctype,scores))
    return accurate(y, predicted)
end

In [None]:
# Dataset-level helpers for the capsule network.

# Error rate over a dataset: one minus the mean minibatch accuracy.
function pererror(w,data,predict)
    return 1 - accurate(w,data,predict)
end

# Mean of the per-minibatch accuracies over every (x,y) pair in `data`.
function accurate(w,data,predict)
    return mean(accurate(w,x,y,predict) for (x,y) in data)
end

# Mean of the per-minibatch losses; keyword options are forwarded to each call.
function lossCaps(w,data,predict; o...)
    return mean(lossCaps(w,x,y,predict;o...) for (x,y) in data)
end

# Gradient function built from `lossCaps` by `grad`.
lossCapsgrad = grad(lossCaps)

In [None]:
# Train the capsule network in place and return a list with a copy of `w` before
# training and after every epoch (epochs+1 entries).
# One Adam() optimizer with default settings is created per element of `w`.
# `lr` is accepted but not used by the current body; it only matters for the
# commented-out Sgd alternative. Remaining keywords `o...` are forwarded to `lossCapsgrad`.
function trainCaps(w,data,predict; epochs=2,lr=0.15,o...)
    # optimizers = map(p->Sgd(lr=lr), w)   # plain SGD alternative
    optimizers = map(p->Adam(), w)
    snapshots = Any[deepcopy(w)]
    for epoch in 1:epochs
        for (x,y) in data
            gradients = lossCapsgrad(w,x,y,predict;o...)
            update!(w,gradients,optimizers)
        end
        push!(snapshots,deepcopy(w))
    end
    return snapshots
end

In [None]:
# Seed the RNG, then create the capsule-network parameters.
srand(1)
wtscap = wtsinit()

In [None]:
# Forward pass on the sample minibatch x taken from dtst above.
yprb1, v1 = capsnet(wtscap,x)

In [None]:
lossCaps(wtscap,x,y,capsnet;training=true)

In [None]:
accurate(wtscap,x,y,capsnet)

In [None]:
# wtscap = wnext  # if read from results file
# trainCaps updates wtscap in place; the lr keyword is passed through but unused there.
@time weights = trainCaps(wtscap,dtrn,capsnet; epochs=5,lr=0.15)  # obviously increase epochs as required

In [None]:
# weights[1] is the snapshot taken before training, so it is skipped here.
@time trnerr = [ pererror(w,dtrn,capsnet) for w in weights[2:end] ]  # [2:end]

In [None]:
@time tsterr = [ pererror(w,dtst,capsnet) for w in weights[2:end] ]  # [2:end]

In [None]:
# Convert the final weights to Ctype arrays before storing them.
wnext = convert(Array{Ctype},weights[end])  # before writing to results file

In [None]:
######### OVERWRITES
# Opening with "w" replaces any existing timitresults.jld.
jldopen("timitresults.jld", "w") do file
    write(file, "wnext", wnext)  
    write(file, "trnerr", trnerr)  
    write(file, "tsterr", tsterr)    
end

In [None]:
D = load("timitresults.jld")  # to read previously stored results

In [None]:
tsterr = D["tsterr"]  # D is a dictionary

In [None]:
wnext = convert(Array{Atype},D["wnext"])  # after reading from results file, use to initialize wtscap

In [None]:
using Plots
plotly()

In [None]:
# One column per curve: training and test error per epoch.
Plots.plot([trnerr tsterr],ylim=(0,1.0),linewidth=3,labels=[:Train_Error :Test_Error],xlabel="Epochs", ylabel="TIMIT Errors")