In [8]:
using Flux, Statistics
using Flux.Data: DataLoader
using Flux: onehotbatch, onecold, logitcrossentropy, throttle, @epochs
using Base.Iterators: repeated
using Parameters: @with_kw
using CUDAapi
using MLDatasets
if has_cuda()  # Check if CUDA is available
    @info "CUDA is on"
    import CuArrays  # If CUDA is available, import CuArrays
    CuArrays.allowscalar(false)
end

In [21]:
@with_kw mutable struct Args
    η::Float64 = 3e-4       # learning rate
    batchsize::Int = 1024   # batch size
    epochs::Int = 10        # number of epochs
    device::Function = gpu  # set as gpu, if gpu available
end

Args

In [22]:
function getdata(args)
    # Loading Dataset
    xtrain, ytrain = MLDatasets.MNIST.traindata(Float32)
    xtest, ytest = MLDatasets.MNIST.testdata(Float32)

    # Reshape Data for flatten the each image into linear array
    xtrain = Flux.flatten(xtrain)
    xtest = Flux.flatten(xtest)

    # One-hot-encode the labels
    ytrain, ytest = onehotbatch(ytrain, 0:9), onehotbatch(ytest, 0:9)

    # Batching
    train_data = DataLoader(xtrain, ytrain, batchsize=args.batchsize, shuffle=true)
    test_data = DataLoader(xtest, ytest, batchsize=args.batchsize)

    return train_data, test_data
end

getdata (generic function with 1 method)

In [12]:
function build_model(; imgsize=(28,28,1), nclasses=10)
    return Chain(
            Dense(prod(imgsize), 32, relu),
            Dense(32, nclasses))
end

build_model (generic function with 1 method)

In [13]:
function loss_all(dataloader, model)
    l = 0f0
    for (x,y) in dataloader
        l += logitcrossentropy(model(x), y)
    end
    l/length(dataloader)
end

loss_all (generic function with 1 method)

In [16]:
function accuracy(data_loader, model)
    acc = 0
    for (x,y) in data_loader
        acc += sum(onecold(cpu(model(x))) .== onecold(cpu(y)))*1 / size(x,2)
    end
    acc/length(data_loader)
end

accuracy (generic function with 1 method)

In [17]:
function train(; kws...)
    # Initializing Model parameters 
    args = Args(; kws...)

    # Load Data
    train_data,test_data = getdata(args)

    # Construct model
    m = build_model()
    train_data = args.device.(train_data)
    test_data = args.device.(train_data)
    m = args.device(m)
    loss(x,y) = logitcrossentropy(m(x), y)
    
    ## Training
    evalcb = () -> @show(loss_all(train_data, m))
    opt = ADAM(args.η)

    @epochs args.epochs Flux.train!(loss, params(m), train_data, opt, cb = evalcb)

    @show accuracy(train_data, m)

    @show accuracy(test_data, m)
end

train (generic function with 1 method)

In [18]:
cd(@__DIR__)
train()

┌ Info: Epoch 1
└ @ Main /Users/yasu/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


loss_all(train_data, m) = 2.2981153f0
loss_all(train_data, m) = 2.2775574f0
loss_all(train_data, m) = 2.257917f0
loss_all(train_data, m) = 2.2389674f0
loss_all(train_data, m) = 2.2206185f0
loss_all(train_data, m) = 2.2027218f0
loss_all(train_data, m) = 2.185217f0
loss_all(train_data, m) = 2.1680458f0
loss_all(train_data, m) = 2.151162f0
loss_all(train_data, m) = 2.1345675f0
loss_all(train_data, m) = 2.118097f0
loss_all(train_data, m) = 2.101754f0
loss_all(train_data, m) = 2.0854404f0
loss_all(train_data, m) = 2.0692234f0
loss_all(train_data, m) = 2.0530176f0
loss_all(train_data, m) = 2.0368364f0
loss_all(train_data, m) = 2.020643f0
loss_all(train_data, m) = 2.0044088f0
loss_all(train_data, m) = 1.988065f0
loss_all(train_data, m) = 1.9716523f0
loss_all(train_data, m) = 1.9551747f0
loss_all(train_data, m) = 1.9386064f0
loss_all(train_data, m) = 1.9219378f0
loss_all(train_data, m) = 1.9051024f0
loss_all(train_data, m) = 1.8881837f0
loss_all(train_data, m) = 1.8712518f0
loss_all(train_data

┌ Info: Epoch 2
└ @ Main /Users/yasu/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


1.3210677f0
loss_all(train_data, m) = 1.3079668f0
loss_all(train_data, m) = 1.2950487f0
loss_all(train_data, m) = 1.282315f0
loss_all(train_data, m) = 1.2697116f0
loss_all(train_data, m) = 1.2572827f0
loss_all(train_data, m) = 1.2450247f0
loss_all(train_data, m) = 1.2329206f0
loss_all(train_data, m) = 1.2209444f0
loss_all(train_data, m) = 1.2091535f0
loss_all(train_data, m) = 1.1975375f0
loss_all(train_data, m) = 1.1860894f0
loss_all(train_data, m) = 1.1748438f0
loss_all(train_data, m) = 1.1637418f0
loss_all(train_data, m) = 1.152799f0
loss_all(train_data, m) = 1.142028f0
loss_all(train_data, m) = 1.1314083f0
loss_all(train_data, m) = 1.1209738f0
loss_all(train_data, m) = 1.1107329f0
loss_all(train_data, m) = 1.1006857f0
loss_all(train_data, m) = 1.0907996f0
loss_all(train_data, m) = 1.0811092f0
loss_all(train_data, m) = 1.0715741f0
loss_all(train_data, m) = 1.0622082f0
loss_all(train_data, m) = 1.0530168f0
loss_all(train_data, m) = 1.0439694f0
loss_all(train_data, m) = 1.0350765f0
los

┌ Info: Epoch 3
└ @ Main /Users/yasu/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


0.8139554f0
loss_all(train_data, m) = 0.80894f0
loss_all(train_data, m) = 0.8040557f0
loss_all(train_data, m) = 0.7992726f0
loss_all(train_data, m) = 0.79458195f0
loss_all(train_data, m) = 0.7899621f0
loss_all(train_data, m) = 0.7854078f0
loss_all(train_data, m) = 0.7808892f0
loss_all(train_data, m) = 0.77639186f0
loss_all(train_data, m) = 0.77199864f0
loss_all(train_data, m) = 0.76767606f0
loss_all(train_data, m) = 0.76341987f0
loss_all(train_data, m) = 0.7592529f0
loss_all(train_data, m) = 0.75513333f0
loss_all(train_data, m) = 0.7510213f0
loss_all(train_data, m) = 0.7469472f0
loss_all(train_data, m) = 0.74290544f0
loss_all(train_data, m) = 0.7389348f0
loss_all(train_data, m) = 0.7350334f0
loss_all(train_data, m) = 0.73118407f0
loss_all(train_data, m) = 0.72737044f0
loss_all(train_data, m) = 0.7236196f0
loss_all(train_data, m) = 0.719921f0
loss_all(train_data, m) = 0.7162845f0
loss_all(train_data, m) = 0.7126949f0
loss_all(train_data, m) = 0.7091521f0
loss_all(train_data, m) = 0.7056

┌ Info: Epoch 4
└ @ Main /Users/yasu/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


loss_all(train_data, m) = 0.61301637f0
loss_all(train_data, m) = 0.61055875f0
loss_all(train_data, m) = 0.60818744f0
loss_all(train_data, m) = 0.6058846f0
loss_all(train_data, m) = 0.6036269f0
loss_all(train_data, m) = 0.6014101f0
loss_all(train_data, m) = 0.5992038f0
loss_all(train_data, m) = 0.59702855f0
loss_all(train_data, m) = 0.59485626f0
loss_all(train_data, m) = 0.59265184f0
loss_all(train_data, m) = 0.59049815f0
loss_all(train_data, m) = 0.58837587f0
loss_all(train_data, m) = 0.5862844f0
loss_all(train_data, m) = 0.584246f0
loss_all(train_data, m) = 0.582225f0
loss_all(train_data, m) = 0.580182f0
loss_all(train_data, m) = 0.5781377f0
loss_all(train_data, m) = 0.5760979f0
loss_all(train_data, m) = 0.574087f0
loss_all(train_data, m) = 0.5721044f0
loss_all(train_data, m) = 0.5701434f0
loss_all(train_data, m) = 0.56819487f0
loss_all(train_data, m) = 0.5662806f0
loss_all(train_data, m) = 0.5643926f0
loss_all(train_data, m) = 0.5625468f0
loss_all(train_data, m) = 0.5607171f0
loss_al

┌ Info: Epoch 5
└ @ Main /Users/yasu/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


0.5059028f0
loss_all(train_data, m) = 0.5045518f0
loss_all(train_data, m) = 0.5032498f0
loss_all(train_data, m) = 0.5019769f0
loss_all(train_data, m) = 0.50073165f0
loss_all(train_data, m) = 0.4994779f0
loss_all(train_data, m) = 0.49825177f0
loss_all(train_data, m) = 0.49702695f0
loss_all(train_data, m) = 0.49575102f0
loss_all(train_data, m) = 0.49450386f0
loss_all(train_data, m) = 0.49326932f0
loss_all(train_data, m) = 0.49204737f0
loss_all(train_data, m) = 0.49086636f0
loss_all(train_data, m) = 0.48969457f0
loss_all(train_data, m) = 0.4884946f0
loss_all(train_data, m) = 0.48728356f0
loss_all(train_data, m) = 0.48607063f0
loss_all(train_data, m) = 0.48486698f0
loss_all(train_data, m) = 0.4836737f0
loss_all(train_data, m) = 0.48249412f0
loss_all(train_data, m) = 0.4813194f0
loss_all(train_data, m) = 0.48017135f0
loss_all(train_data, m) = 0.47904104f0
loss_all(train_data, m) = 0.47794765f0
loss_all(train_data, m) = 0.47686383f0
loss_all(train_data, m) = 0.4758031f0
loss_all(train_data, 

┌ Info: Epoch 6
└ @ Main /Users/yasu/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


loss_all(train_data, m) = 0.4440185f0
loss_all(train_data, m) = 0.4430972f0
loss_all(train_data, m) = 0.44222453f0
loss_all(train_data, m) = 0.44138908f0
loss_all(train_data, m) = 0.4405753f0
loss_all(train_data, m) = 0.4397865f0
loss_all(train_data, m) = 0.43898243f0
loss_all(train_data, m) = 0.4382046f0
loss_all(train_data, m) = 0.4374311f0
loss_all(train_data, m) = 0.4365997f0
loss_all(train_data, m) = 0.43578565f0
loss_all(train_data, m) = 0.4349782f0
loss_all(train_data, m) = 0.4341754f0
loss_all(train_data, m) = 0.4334087f0
loss_all(train_data, m) = 0.4326441f0
loss_all(train_data, m) = 0.43184876f0
loss_all(train_data, m) = 0.43103778f0
loss_all(train_data, m) = 0.43022648f0
loss_all(train_data, m) = 0.42941618f0
loss_all(train_data, m) = 0.4286092f0
loss_all(train_data, m) = 0.42781374f0
loss_all(train_data, m) = 0.42702064f0
loss_all(train_data, m) = 0.42625147f0
loss_all(train_data, m) = 0.4254975f0
loss_all(train_data, m) = 0.4247793f0
loss_all(train_data, m) = 0.4240668f0
l

┌ Info: Epoch 7
└ @ Main /Users/yasu/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


0.4010238f0
loss_all(train_data, m) = 0.40040827f0
loss_all(train_data, m) = 0.39982337f0
loss_all(train_data, m) = 0.3992564f0
loss_all(train_data, m) = 0.39871255f0
loss_all(train_data, m) = 0.3981482f0
loss_all(train_data, m) = 0.3976123f0
loss_all(train_data, m) = 0.39708227f0
loss_all(train_data, m) = 0.39649406f0
loss_all(train_data, m) = 0.39591444f0
loss_all(train_data, m) = 0.39533764f0
loss_all(train_data, m) = 0.3947617f0
loss_all(train_data, m) = 0.39421895f0
loss_all(train_data, m) = 0.39367354f0
loss_all(train_data, m) = 0.39309356f0
loss_all(train_data, m) = 0.39249423f0
loss_all(train_data, m) = 0.3918971f0
loss_all(train_data, m) = 0.39129636f0
loss_all(train_data, m) = 0.39069796f0
loss_all(train_data, m) = 0.39011168f0
loss_all(train_data, m) = 0.38952753f0
loss_all(train_data, m) = 0.3889674f0
loss_all(train_data, m) = 0.38842198f0
loss_all(train_data, m) = 0.38791177f0
loss_all(train_data, m) = 0.3874065f0
loss_all(train_data, m) = 0.38691065f0
loss_all(train_data,

┌ Info: Epoch 8
└ @ Main /Users/yasu/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


0.37053004f0
loss_all(train_data, m) = 0.37005958f0
loss_all(train_data, m) = 0.3696195f0
loss_all(train_data, m) = 0.3691975f0
loss_all(train_data, m) = 0.3687986f0
loss_all(train_data, m) = 0.3683791f0
loss_all(train_data, m) = 0.36798912f0
loss_all(train_data, m) = 0.3676081f0
loss_all(train_data, m) = 0.3671721f0
loss_all(train_data, m) = 0.36673823f0
loss_all(train_data, m) = 0.36630508f0
loss_all(train_data, m) = 0.36586913f0
loss_all(train_data, m) = 0.36546358f0
loss_all(train_data, m) = 0.36505082f0
loss_all(train_data, m) = 0.36460167f0
loss_all(train_data, m) = 0.36413196f0
loss_all(train_data, m) = 0.3636661f0
loss_all(train_data, m) = 0.3631949f0
loss_all(train_data, m) = 0.36272463f0
loss_all(train_data, m) = 0.36226708f0
loss_all(train_data, m) = 0.36180946f0
loss_all(train_data, m) = 0.36137506f0
loss_all(train_data, m) = 0.36095774f0
loss_all(train_data, m) = 0.36057445f0
loss_all(train_data, m) = 0.3601955f0
loss_all(train_data, m) = 0.35982203f0
loss_all(train_data, 

┌ Info: Epoch 9
└ @ Main /Users/yasu/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


loss_all(train_data, m) = 0.34758905f0
loss_all(train_data, m) = 0.34718093f0
loss_all(train_data, m) = 0.34680343f0
loss_all(train_data, m) = 0.34645495f0
loss_all(train_data, m) = 0.34612507f0
loss_all(train_data, m) = 0.34581918f0
loss_all(train_data, m) = 0.34549105f0
loss_all(train_data, m) = 0.3451957f0
loss_all(train_data, m) = 0.34491035f0
loss_all(train_data, m) = 0.344577f0
loss_all(train_data, m) = 0.34424326f0
loss_all(train_data, m) = 0.34390843f0
loss_all(train_data, m) = 0.34356788f0
loss_all(train_data, m) = 0.3432545f0
loss_all(train_data, m) = 0.34292975f0
loss_all(train_data, m) = 0.34256428f0
loss_all(train_data, m) = 0.3421744f0
loss_all(train_data, m) = 0.3417915f0
loss_all(train_data, m) = 0.34140337f0
loss_all(train_data, m) = 0.34101942f0
loss_all(train_data, m) = 0.34064895f0
loss_all(train_data, m) = 0.3402781f0
loss_all(train_data, m) = 0.3399288f0
loss_all(train_data, m) = 0.33959642f0
loss_all(train_data, m) = 0.33929637f0
loss_all(train_data, m) = 0.33899

┌ Info: Epoch 10
└ @ Main /Users/yasu/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


loss_all(train_data, m) = 0.32891357f0
loss_all(train_data, m) = 0.32857573f0
loss_all(train_data, m) = 0.3282662f0
loss_all(train_data, m) = 0.32798526f0
loss_all(train_data, m) = 0.32772344f0
loss_all(train_data, m) = 0.32748604f0
loss_all(train_data, m) = 0.32722554f0
loss_all(train_data, m) = 0.32699957f0
loss_all(train_data, m) = 0.3267869f0
loss_all(train_data, m) = 0.32652876f0
loss_all(train_data, m) = 0.32626396f0
loss_all(train_data, m) = 0.3259951f0
loss_all(train_data, m) = 0.325717f0
loss_all(train_data, m) = 0.32546526f0
loss_all(train_data, m) = 0.32519805f0
loss_all(train_data, m) = 0.32488668f0
loss_all(train_data, m) = 0.32454887f0
loss_all(train_data, m) = 0.3242214f0
loss_all(train_data, m) = 0.32389003f0
loss_all(train_data, m) = 0.32356456f0
loss_all(train_data, m) = 0.32325408f0
loss_all(train_data, m) = 0.32294258f0
loss_all(train_data, m) = 0.3226528f0
loss_all(train_data, m) = 0.3223817f0
loss_all(train_data, m) = 0.3221424f0
loss_all(train_data, m) = 0.321904

0.9150939451382695