## Add the Fashion MNIST model to the Flux model-zoo ##

1. Write Flux model
2. Benchmark it against Pytorch, Tensorflow equivalents

### 1. Import Libraries ###
Note: Due to compatability issues on my laptop, I cannot properly install CUDA (hence Cudnn) and cannot load the arrays to my GPU :(

In [1]:
using MLDatasets
using Flux, Statistics
using Flux: onehotbatch, onecold, crossentropy, throttle
using Base.Iterators: repeated, partition
using Printf, BSON

### 2. Data Preparation ###
I load the train and test sets of FashionMNIST to accessible formats, and define a function to partition data into mini-batches.

In [2]:
train_x, train_y = FashionMNIST.traindata()
test_x, test_y = FashionMNIST.testdata()
train_x, test_x = [train_x[:, :, i] for i in 1:size(train_x,3)], [test_x[:, :, i] for i in 1:size(test_x,3)]


# Mini-batch function:
function mb(img, label, indexx) 
    mb_x = Array{Float32}(undef, size(img[1])..., 1, length(indexx)) # placeholder
    for i in 1:length(indexx)
        mb_x[:, :, :, i] = Float32.(img[indexx[i]]) # Asigns images into batches of 64
    end
    # This encodes labels into one-hot vector representation
    mb_y = onehotbatch(label[indexx], 0:9)
    return (mb_x, mb_y)
end

mb (generic function with 1 method)

### 3. Define the Model ###
I construct my ConvNet and define Mean Squared Error as the cost function.

In [3]:
ConvNet = Chain(
    # Input: (28, 28, 1, 64)
    
    # After Conv1: (14, 14, 16, 64)
    Conv((3, 3), 1=>16, pad=(1,1), relu),
    MaxPool((2,2), stride=(2,2)),

    # After Conv2: (7, 7, 32, 64)
    Conv((3, 3), 16=>32, pad=(1,1), relu),
    MaxPool((2,2), stride=(2,2)),

    # After Conv3: (3, 3, 64, 64)
    Conv((3, 3), 32=>64, pad=(1,1), relu),
    MaxPool((2,2), stride=(2,2)),

    # Flatten to (64, 576)
    x -> reshape(x, :, size(x, 4)),
    Dense(576, 72),
    x -> relu.(x),
    Dense(72, 10),
    x -> relu.(x),
    
    # Get probability predictions, corresponding to each class
    softmax,
)

# MSE Loss
function loss(x, y)
    x_aug = x .+ 0.1f0*randn(eltype(x), size(x)) 
    # Augmenting x helps increase inter-class variance (decision boundary) >> improves model discriminativeness
    y_hat = ConvNet(x_aug)
    return crossentropy(y_hat, y)
end

loss (generic function with 1 method)

### 4. Define Hyperparameters ###
I set the mini-batch size as 64, learning-rate as 0.001, epoch number as 200; and use the Adam optimiser.

In [4]:
mb_size = 64
mb_idxs = partition(1:length(train_x), mb_size)
trainset = [mb(train_x, train_y, i) for i in mb_idxs]
testset = mb(test_x, test_y, 1:length(test_x))
deleteat!(trainset, 938) # dims error

# pass data to GPU for acceleration
#trainset, testset, ConvNet = gpu.(trainset), gpu.(testset), gpu(ConvNet)

accuracy(x, y) = mean(onecold(ConvNet(x)) .== onecold(y))
optimiser = ADAM(0.001)
best_acc, last_improvement, epoch_num, threshold = 0.0, 0, 100, 0.95

(0.0, 0, 100, 0.95)

In [5]:
# Sanity Check: (classes=10, minibatch=64) array
ConvNet(trainset[1][1])

10×64 Array{Float32,2}:
 0.105307   0.108713   0.102776   …  0.108729   0.102043   0.10187  
 0.107922   0.101713   0.100319      0.104436   0.109957   0.10227  
 0.0980129  0.0986968  0.0995453     0.0983543  0.0976855  0.0992966
 0.0980129  0.0986968  0.0995453     0.0983543  0.0976855  0.0992966
 0.0980129  0.0986968  0.0995453     0.0983543  0.0976855  0.0992966
 0.0980129  0.0986968  0.0995453  …  0.0983543  0.101364   0.0996767
 0.0980129  0.0986968  0.0995453     0.0983543  0.0976855  0.0992966
 0.100681   0.0986968  0.100088      0.0983543  0.100523   0.100404 
 0.0980129  0.0986968  0.0995453     0.0983543  0.0976855  0.0992966
 0.0980129  0.0986968  0.0995453     0.0983543  0.0976855  0.0992966

### 5. Begin Training & Testing ###

In [6]:
time = 0.0

for epoch in 1:epoch_num
    global best_acc, last_improvement
    
    tic = time_ns()
    # Train for a single epoch
    Flux.train!(loss, params(ConvNet), trainset, optimiser)
    toc = time_ns()
    time += toc-tic
    println(time/1.0e9)
    # Calculate accuracy:
    acc = accuracy(testset...)
    @info(@sprintf("[%d]: Test accuracy: %.4f", epoch, acc))

    # Early exit at 95% accuracy
    if acc >= threshold
        @info(" -> Early-exiting: We reached our target accuracy of 90.0%")
        break
    end

    if acc >= best_acc
        @info(" -> New best accuracy! Saving model out to fmnist_flux.bson")
        BSON.@save joinpath("fmnist_flux_4.bson") ConvNet epoch acc
        best_acc, last_improvement = acc, epoch
    end

    # Learning rate decay to dampen oscillations
    if epoch - last_improvement >= 15 && optimiser.eta > 1e-06 #1e-06
        optimiser.eta /= 10.0
        @warn(" -> Haven't improved in a while, dropping learning rate to $(optimiser.eta)!")
        last_improvement = epoch
    end
    
    if epoch - last_improvement >= 25
        @warn(" -> We're calling this converged.")
        break
    end
end

96.1566225


┌ Info: [1]: Test accuracy: 0.6944
└ @ Main In[6]:14
┌ Info:  -> New best accuracy! Saving model out to fmnist_flux.bson
└ @ Main In[6]:23


183.6960653


┌ Info: [2]: Test accuracy: 0.7135
└ @ Main In[6]:14
┌ Info:  -> New best accuracy! Saving model out to fmnist_flux.bson
└ @ Main In[6]:23


270.4484071


┌ Info: [3]: Test accuracy: 0.7172
└ @ Main In[6]:14
┌ Info:  -> New best accuracy! Saving model out to fmnist_flux.bson
└ @ Main In[6]:23


356.0775229


┌ Info: [4]: Test accuracy: 0.7200
└ @ Main In[6]:14
┌ Info:  -> New best accuracy! Saving model out to fmnist_flux.bson
└ @ Main In[6]:23


440.533314201


┌ Info: [5]: Test accuracy: 0.7119
└ @ Main In[6]:14


524.9961107


┌ Info: [6]: Test accuracy: 0.7162
└ @ Main In[6]:14


609.4269898


┌ Info: [7]: Test accuracy: 0.7201
└ @ Main In[6]:14
┌ Info:  -> New best accuracy! Saving model out to fmnist_flux.bson
└ @ Main In[6]:23


694.5362108


┌ Info: [8]: Test accuracy: 0.7182
└ @ Main In[6]:14


779.610495601


┌ Info: [9]: Test accuracy: 0.7210
└ @ Main In[6]:14
┌ Info:  -> New best accuracy! Saving model out to fmnist_flux.bson
└ @ Main In[6]:23


862.8768959


┌ Info: [10]: Test accuracy: 0.7224
└ @ Main In[6]:14
┌ Info:  -> New best accuracy! Saving model out to fmnist_flux.bson
└ @ Main In[6]:23


947.4686651


┌ Info: [11]: Test accuracy: 0.7191
└ @ Main In[6]:14


1031.7048632


┌ Info: [12]: Test accuracy: 0.7188
└ @ Main In[6]:14


1118.7910213


┌ Info: [13]: Test accuracy: 0.7183
└ @ Main In[6]:14


1204.0653898


┌ Info: [14]: Test accuracy: 0.7197
└ @ Main In[6]:14


1292.711714701


┌ Info: [15]: Test accuracy: 0.7175
└ @ Main In[6]:14


1377.480040201


┌ Info: [16]: Test accuracy: 0.7250
└ @ Main In[6]:14
┌ Info:  -> New best accuracy! Saving model out to fmnist_flux.bson
└ @ Main In[6]:23


1464.063865801


┌ Info: [17]: Test accuracy: 0.7248
└ @ Main In[6]:14


1549.036699501


┌ Info: [18]: Test accuracy: 0.7211
└ @ Main In[6]:14


1633.720768902


┌ Info: [19]: Test accuracy: 0.7184
└ @ Main In[6]:14


1735.948451502


┌ Info: [20]: Test accuracy: 0.7186
└ @ Main In[6]:14


1903.038503301


┌ Info: [21]: Test accuracy: 0.7166
└ @ Main In[6]:14


2065.6680757


┌ Info: [22]: Test accuracy: 0.7221
└ @ Main In[6]:14


2231.0625384


┌ Info: [23]: Test accuracy: 0.7214
└ @ Main In[6]:14


2396.656911199


┌ Info: [24]: Test accuracy: 0.7158
└ @ Main In[6]:14


2561.761726698


┌ Info: [25]: Test accuracy: 0.7214
└ @ Main In[6]:14


2730.724183498


┌ Info: [26]: Test accuracy: 0.7191
└ @ Main In[6]:14


2888.426339998


┌ Info: [27]: Test accuracy: 0.7242
└ @ Main In[6]:14


3053.711927998


┌ Info: [28]: Test accuracy: 0.7159
└ @ Main In[6]:14


3218.924178298


┌ Info: [29]: Test accuracy: 0.7186
└ @ Main In[6]:14


3384.295225198


┌ Info: [30]: Test accuracy: 0.7194
└ @ Main In[6]:14


3546.795807598


┌ Info: [31]: Test accuracy: 0.7209
└ @ Main In[6]:14
└ @ Main In[6]:31


3705.591788798


┌ Info: [32]: Test accuracy: 0.7307
└ @ Main In[6]:14
┌ Info:  -> New best accuracy! Saving model out to fmnist_flux.bson
└ @ Main In[6]:23


3864.098632098


┌ Info: [33]: Test accuracy: 0.7309
└ @ Main In[6]:14
┌ Info:  -> New best accuracy! Saving model out to fmnist_flux.bson
└ @ Main In[6]:23


4022.753495397


┌ Info: [34]: Test accuracy: 0.7310
└ @ Main In[6]:14
┌ Info:  -> New best accuracy! Saving model out to fmnist_flux.bson
└ @ Main In[6]:23


4185.281503497


┌ Info: [35]: Test accuracy: 0.7294
└ @ Main In[6]:14


4332.523490697


┌ Info: [36]: Test accuracy: 0.7298
└ @ Main In[6]:14


4481.257805997


┌ Info: [37]: Test accuracy: 0.7296
└ @ Main In[6]:14


4627.256208697


┌ Info: [38]: Test accuracy: 0.7298
└ @ Main In[6]:14


4774.049273496


┌ Info: [39]: Test accuracy: 0.7290
└ @ Main In[6]:14


4863.594029196


┌ Info: [40]: Test accuracy: 0.7293
└ @ Main In[6]:14


4946.587481395


┌ Info: [41]: Test accuracy: 0.7297
└ @ Main In[6]:14


5029.790838996


┌ Info: [42]: Test accuracy: 0.7281
└ @ Main In[6]:14


5112.792836195


┌ Info: [43]: Test accuracy: 0.7290
└ @ Main In[6]:14


5195.511404095


┌ Info: [44]: Test accuracy: 0.7291
└ @ Main In[6]:14


5278.021792795


┌ Info: [45]: Test accuracy: 0.7296
└ @ Main In[6]:14


5361.827755694


┌ Info: [46]: Test accuracy: 0.7308
└ @ Main In[6]:14


5446.289305794


┌ Info: [47]: Test accuracy: 0.7300
└ @ Main In[6]:14


5531.293768593


┌ Info: [48]: Test accuracy: 0.7294
└ @ Main In[6]:14


5616.547103794


┌ Info: [49]: Test accuracy: 0.7316
└ @ Main In[6]:14
┌ Info:  -> New best accuracy! Saving model out to fmnist_flux.bson
└ @ Main In[6]:23


5703.114575595


┌ Info: [50]: Test accuracy: 0.7312
└ @ Main In[6]:14


5788.996474295


┌ Info: [51]: Test accuracy: 0.7300
└ @ Main In[6]:14


5931.497084196


┌ Info: [52]: Test accuracy: 0.7300
└ @ Main In[6]:14


6110.474124895


┌ Info: [53]: Test accuracy: 0.7291
└ @ Main In[6]:14


6276.605818995


┌ Info: [54]: Test accuracy: 0.7291
└ @ Main In[6]:14


6434.991721794


┌ Info: [55]: Test accuracy: 0.7295
└ @ Main In[6]:14


6537.844718594


┌ Info: [56]: Test accuracy: 0.7287
└ @ Main In[6]:14


6690.171750094


┌ Info: [57]: Test accuracy: 0.7271
└ @ Main In[6]:14


6842.497078294


┌ Info: [58]: Test accuracy: 0.7291
└ @ Main In[6]:14


6986.831183194


┌ Info: [59]: Test accuracy: 0.7288
└ @ Main In[6]:14


7125.190337795


┌ Info: [60]: Test accuracy: 0.7300
└ @ Main In[6]:14


7264.052668295


┌ Info: [61]: Test accuracy: 0.7307
└ @ Main In[6]:14


7410.103209495


┌ Info: [62]: Test accuracy: 0.7288
└ @ Main In[6]:14


7551.185552395


┌ Info: [63]: Test accuracy: 0.7291
└ @ Main In[6]:14


7695.087532396


┌ Info: [64]: Test accuracy: 0.7291
└ @ Main In[6]:14
└ @ Main In[6]:31


7836.014751196


┌ Info: [65]: Test accuracy: 0.7283
└ @ Main In[6]:14


7975.258991297


┌ Info: [66]: Test accuracy: 0.7280
└ @ Main In[6]:14


8103.636392597


┌ Info: [67]: Test accuracy: 0.7274
└ @ Main In[6]:14


8237.871211597


┌ Info: [68]: Test accuracy: 0.7284
└ @ Main In[6]:14


8377.822132796


┌ Info: [69]: Test accuracy: 0.7283
└ @ Main In[6]:14


8519.502455496


┌ Info: [70]: Test accuracy: 0.7281
└ @ Main In[6]:14


8668.531965696


┌ Info: [71]: Test accuracy: 0.7281
└ @ Main In[6]:14


8820.021408596


┌ Info: [72]: Test accuracy: 0.7282
└ @ Main In[6]:14


8972.208271897


┌ Info: [73]: Test accuracy: 0.7284
└ @ Main In[6]:14


9100.730591198


┌ Info: [74]: Test accuracy: 0.7281
└ @ Main In[6]:14


9231.975302398


┌ Info: [75]: Test accuracy: 0.7277
└ @ Main In[6]:14


9367.282362198


┌ Info: [76]: Test accuracy: 0.7280
└ @ Main In[6]:14


9495.353638599


┌ Info: [77]: Test accuracy: 0.7282
└ @ Main In[6]:14


9630.440631799


┌ Info: [78]: Test accuracy: 0.7277
└ @ Main In[6]:14


9769.775296198


┌ Info: [79]: Test accuracy: 0.7282
└ @ Main In[6]:14
└ @ Main In[6]:31


9919.344010098


┌ Info: [80]: Test accuracy: 0.7278
└ @ Main In[6]:14


10086.310693398


┌ Info: [81]: Test accuracy: 0.7279
└ @ Main In[6]:14


10226.405282998


┌ Info: [82]: Test accuracy: 0.7277
└ @ Main In[6]:14


10350.009525098


┌ Info: [83]: Test accuracy: 0.7277
└ @ Main In[6]:14


10432.972572998


┌ Info: [84]: Test accuracy: 0.7276
└ @ Main In[6]:14


10515.859137298


┌ Info: [85]: Test accuracy: 0.7279
└ @ Main In[6]:14


10599.251768097


┌ Info: [86]: Test accuracy: 0.7278
└ @ Main In[6]:14


10682.560009098


┌ Info: [87]: Test accuracy: 0.7279
└ @ Main In[6]:14


10766.275531197


┌ Info: [88]: Test accuracy: 0.7279
└ @ Main In[6]:14


10849.323894797


┌ Info: [89]: Test accuracy: 0.7277
└ @ Main In[6]:14


10934.484979598


┌ Info: [90]: Test accuracy: 0.7277
└ @ Main In[6]:14


11024.546242198


┌ Info: [91]: Test accuracy: 0.7278
└ @ Main In[6]:14


11107.246617797


┌ Info: [92]: Test accuracy: 0.7278
└ @ Main In[6]:14


11190.337343897


┌ Info: [93]: Test accuracy: 0.7277
└ @ Main In[6]:14


11275.413123296


┌ Info: [94]: Test accuracy: 0.7277
└ @ Main In[6]:14
└ @ Main In[6]:31


11364.359599896


┌ Info: [95]: Test accuracy: 0.7277
└ @ Main In[6]:14


11449.792366397


┌ Info: [96]: Test accuracy: 0.7277
└ @ Main In[6]:14


11534.143077696


┌ Info: [97]: Test accuracy: 0.7277
└ @ Main In[6]:14


11618.230382796


┌ Info: [98]: Test accuracy: 0.7277
└ @ Main In[6]:14


11702.651708896


┌ Info: [99]: Test accuracy: 0.7277
└ @ Main In[6]:14


11789.689309796


┌ Info: [100]: Test accuracy: 0.7277
└ @ Main In[6]:14


In [7]:
time /= 1.0e9
print(time)

11789.689309796