# Julia 深度學習：卷積神經網路模型簡介

In [1]:
using Flux
using Flux.Data: DataLoader
using Flux.Optimise: Optimiser, WeightDecay
using Flux: onehotbatch, onecold, logitcrossentropy
using Statistics, Random
import ProgressMeter
import MLDatasets
using CUDAapi

In [2]:
η = 5e-4             # learning rate
λ = 0                # L2 regularizer param, implemented as weight decay
epochs = 20          # number of epochs
batchsize = 256      # batch size

## 使用 CUDA

In [3]:
use_cuda = CUDAapi.has_cuda_gpu()
if use_cuda
    device = gpu
    println("Training on GPU")
else
    device = cpu
    println("Training on CPU")
end

Training on GPU


## 載入資料

In [4]:
function get_data(batchsize=256)
    xtrain, ytrain = MLDatasets.MNIST.traindata(Float32)
    xtest , ytest  = MLDatasets.MNIST.testdata(Float32)

    xtrain = reshape(xtrain, 28, 28, 1, :)
    xtest  = reshape(xtest, 28, 28, 1, :)

    ytrain = onehotbatch(ytrain, 0:9)
    ytest  = onehotbatch(ytest, 0:9)

    train_loader = DataLoader(xtrain, ytrain, batchsize=batchsize, shuffle=true)
    test_loader  = DataLoader(xtest, ytest,  batchsize=batchsize)
    
    return train_loader, test_loader
end

train_loader, test_loader = get_data(batchsize);

## CNN 模型

In [5]:
# write your model here
num_params(model) = sum(length, Flux.params(model))

function buildModel(imgsize=(28,28,1))
    return Chain(
    x -> reshape(x, imgsize..., :),
    Conv((3, 3), 1=>16, pad=(1,1), relu),
    MaxPool((2,2)),
    Conv((3, 3), 16=>32, pad=(1,1), relu),
    MaxPool((2,2)),
    Conv((3, 3), 32=>32, pad=(1,1), relu),
    MaxPool((2,2)),
    flatten,
    Dense(288, 10),
    softmax)
end

model = buildModel() |> device;
println("CNNs model: $(num_params(model)) trainable params");

CNNs model: 16938 trainable params


In [6]:
fake_input = randn((28, 28, 1)) |> device
fake_output = model(fake_input)

10×1 CuArrays.CuArray{Float32,2,Nothing}:
 0.08607546
 0.08095291
 0.119251266
 0.084274895
 0.08514476
 0.11652309
 0.10538427
 0.10048574
 0.09650776
 0.12539984

## 損失函數

In [7]:
loss(ŷ, y) = logitcrossentropy(ŷ, y)

round4(x) = round(x, digits=4)

function calc_loss_accuracy(loader, model, device)
    l = 0f0
    acc = 0
    ntot = 0
    for (x, y) in loader
        x, y = x |> device, y |> device
        ŷ = model(x)
        l += loss(ŷ, y) * size(x)[end]        
        acc += sum(onecold(ŷ |> cpu) .== onecold(y |> cpu))
        ntot += size(x)[end]
    end
    return (loss = l/ntot |> round4, acc = acc/ntot*100 |> round4)
end

calc_loss_accuracy (generic function with 1 method)

## Callback 函式

In [8]:
# callback function
function callback(epoch)
    train = calc_loss_accuracy(train_loader, model, device)
    test = calc_loss_accuracy(test_loader, model, device)        
    println("Epoch: $epoch   Train: $(train)   Test: $(test)")
end

callback (generic function with 1 method)

## 模型訓練

In [9]:
# define optimizer
opt = ADAM(η) 
if λ > 0 
    opt = Optimiser(opt, WeightDecay(λ))
end

In [10]:
params_model = Flux.params(model)

println("Starting Training")
callback(0)

for epoch in 1:epochs
    progress = ProgressMeter.Progress(length(train_loader))
    
    # update parameters
    for (x, y) in train_loader
        x = x |> device
        y = y |> device
        grads_model = Flux.gradient(params_model) do
            ŷ = model(x)
            loss(ŷ, y)
        end
        Flux.Optimise.update!(opt, params_model, grads_model)
        ProgressMeter.next!(progress)   # comment out for no progress bar
    end
    
    #  logging
    if epoch % 5 == 0
        callback(epoch)
    end
end

Starting Training
Epoch: 0   Train: (loss = 2.3025f0, acc = 10.8767)   Test: (loss = 2.3024f0, acc = 11.33)

[32mProgress:   0%|█                                        |  ETA: 2:26:03[39m




[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:40[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:02[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:02[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:02[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:02[39m


Epoch: 5   Train: (loss = 1.5738f0, acc = 88.895)   Test: (loss = 1.5719f0, acc = 89.04)

[32mProgress:   4%|██                                       |  ETA: 0:00:02[39m




[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:02[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:02[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:02[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:02[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:02[39m


Epoch: 10   Train: (loss = 1.5641f0, acc = 89.74)   Test: (loss = 1.5645f0, acc = 89.67)

[32mProgress:   4%|██                                       |  ETA: 0:00:02[39m




[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:02[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:02[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:02[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:02[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:02[39m


Epoch: 15   Train: (loss = 1.5621f0, acc = 89.93)   Test: (loss = 1.5635f0, acc = 89.77)

[32mProgress:   4%|██                                       |  ETA: 0:00:02[39m




[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:02[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:02[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:02[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:02[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:02[39m


Epoch: 20   Train: (loss = 1.5579f0, acc = 90.315)   Test: (loss = 1.5593f0, acc = 90.21)


## 模型評估

In [11]:
test = calc_loss_accuracy(test_loader, model, device)        
println("Test: $(test)")

Test: (loss = 1.5593f0, acc = 90.21)
