In [1]:
# Activate the project environment located in the current working directory so
# that the package operations and `using` statements below resolve against its
# Project.toml / Manifest.toml.
import Pkg
Pkg.activate(".")

[32m[1m  Activating[22m[39m project at `~/Documents/programming/BME-574-2024`


In [3]:
# Add every package this notebook loads to the active project.
Pkg.add(["Lux", "ADTypes", "MLUtils", "Optimisers", "Zygote", "OneHotArrays", "Random", "Statistics", "Printf", "MLDatasets"])

[32m[1m    Updating[22m[39m registry at `~/.julia/registries/JSMLComponents.toml`
[32m[1m    Updating[22m[39m registry at `~/.julia/registries/JuliaComputingRegistry.toml`
[32m[1m    Updating[22m[39m registry at `~/.julia/registries/JuliaHubRegistry.toml`
[32m[1m   Resolving[22m[39m package versions...
[32m[1m  No Changes[22m[39m to `~/Documents/programming/BME-574-2024/Project.toml`
[32m[1m  No Changes[22m[39m to `~/Documents/programming/BME-574-2024/Manifest.toml`
[32m[1mPrecompiling[22m[39m project...
[32m  ✓ [39m[90mXSLT_jll[39m
[32m  ✓ [39m[90mSparseMatrixColorings → SparseMatrixColoringsColorsExt[39m
[32m  ✓ [39mLsqFit
[32m  ✓ [39m[90mSteadyStateDiffEq[39m
[32m  ✓ [39m[90mPlots → UnitfulExt[39m
[32m  ✓ [39m[90mLinearSolve → LinearSolveKernelAbstractionsExt[39m
[32m  ✓ [39m[90mRootedTrees → PlotsExt[39m
[32m  ✓ [39m[90mNonlinearSolve → NonlinearSolveNLsolveExt[39m
[32m  ✓ [39m[90mSundials[39m
[32m  ✓ [39m[90mStochast

In [6]:
using Lux, ADTypes, MLUtils, Optimisers, Zygote, OneHotArrays, Random, Statistics, Printf
using MLDatasets: MNIST

In [7]:
"""
    loadmnist(batchsize, train_split)

Load the first 60000 images of the MNIST `:train` split and return a tuple
`(train_loader, test_loader)` of `DataLoader`s.

Images are converted to `Float32` and reshaped to `(H, W, 1, N)`; labels are
one-hot encoded over the classes `0:9`. The observations are divided with
`splitobs` at `train_split`. Only the training loader shuffles its data; both
loaders use minibatches of size `batchsize`.
"""
function loadmnist(batchsize, train_split)
    nsamples = 60000
    mnist = MNIST(; split=:train)
    pixels = mnist.features[:, :, 1:nsamples]
    digits = mnist.targets[1:nsamples]

    # Insert a singleton channel axis: (H, W, N) -> (H, W, C = 1, N).
    height, width, nimgs = size(pixels, 1), size(pixels, 2), size(pixels, 3)
    x_data = Float32.(reshape(pixels, height, width, 1, nimgs))
    y_data = onehotbatch(digits, 0:9)

    (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split)

    # `collect` materialises the pieces returned by `splitobs` before batching.
    # The training loader shuffles; the test loader keeps its order.
    train_loader = DataLoader(collect.((x_train, y_train)); batchsize, shuffle=true)
    test_loader = DataLoader(collect.((x_test, y_test)); batchsize, shuffle=false)

    return (train_loader, test_loader)
end

loadmnist (generic function with 1 method)

In [8]:
# Small convolutional classifier: two Conv + MaxPool stages followed by a
# three-layer dense head that emits 10 raw class scores (no final activation).
lux_model = Chain(
    Conv((5, 5), 1 => 6, relu),
    MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu),
    MaxPool((2, 2)),
    FlattenLayer(3),
    # Dense head; its input width (256) must match the flattened feature size.
    Chain(
        Dense(256 => 128, relu),
        Dense(128 => 84, relu),
        Dense(84 => 10),
    ),
)

Chain(
    layer_1 = Conv((5, 5), 1 => 6, relu),  [90m# 156 parameters[39m
    layer_2 = MaxPool((2, 2)),
    layer_3 = Conv((5, 5), 6 => 16, relu),  [90m# 2_416 parameters[39m
    layer_4 = MaxPool((2, 2)),
    layer_5 = FlattenLayer{Static.StaticInt{3}}(static(3)),
    layer_6 = Chain(
        layer_1 = Dense(256 => 128, relu),  [90m# 32_896 parameters[39m
        layer_2 = Dense(128 => 84, relu),  [90m# 10_836 parameters[39m
        layer_3 = Dense(84 => 10),      [90m# 850 parameters[39m
    ),
) [90m        # Total: [39m47_154 parameters,
[90m          #        plus [39m0 states.

In [32]:
# Wider variant: two stacked 3×3 convolutions, a single 3×3 max-pool, dropout,
# then a two-layer dense head that emits 10 raw class scores.
lux_model2 = Chain(
    Conv((3, 3), 1 => 32, relu),
    Conv((3, 3), 32 => 64, relu),
    MaxPool((3, 3)),
    Dropout(0.5),
    FlattenLayer(3),
    # Dense head; its input width (4096) must match the flattened feature size.
    Chain(
        Dense(4096 => 250, relu),
        Dense(250 => 10),
    ),
)

Chain(
    layer_1 = Conv((3, 3), 1 => 32, relu),  [90m# 320 parameters[39m
    layer_2 = Conv((3, 3), 32 => 64, relu),  [90m# 18_496 parameters[39m
    layer_3 = MaxPool((3, 3)),
    layer_4 = Dropout(0.5),
    layer_5 = FlattenLayer{Static.StaticInt{3}}(static(3)),
    layer_6 = Chain(
        layer_1 = Dense(4096 => 250, relu),  [90m# 1_024_250 parameters[39m
        layer_2 = Dense(250 => 10),     [90m# 2_510 parameters[39m
    ),
) [90m        # Total: [39m1_045_576 parameters,
[90m          #        plus [39m2 states.

In [9]:
# Classification loss shared by the training code below. It is configured with
# `logits=Val(true)`; the models in this notebook end in a `Dense(… => 10)` layer
# without an activation, so their outputs are unnormalised scores.
const loss = CrossEntropyLoss(; logits=Val(true))

"""
    accuracy(model, ps, st, dataloader)

Return the fraction of observations in `dataloader` for which the class
predicted by `model` (with parameters `ps` and states `st`) equals the target
class.

The states are switched to test mode via `Lux.testmode` before evaluation.
`dataloader` must yield `(x, y)` pairs with one-hot encoded `y`.
"""
function accuracy(model, ps, st, dataloader)
    eval_state = Lux.testmode(st)
    ncorrect = 0
    nseen = 0
    for (x, y) in dataloader
        truth = onecold(y)
        # The model returns `(output, state)`; only the output is needed here.
        scores = first(model(x, ps, eval_state))
        guess = onecold(Array(scores))
        ncorrect += count(truth .== guess)
        nseen += length(truth)
    end
    return ncorrect / nseen
end

accuracy (generic function with 1 method)

In [10]:
"""
    train(model; rng=Xoshiro(0), nepochs=10, batchsize=128, train_split=0.9,
          learning_rate=3.0f-4, kwargs...)

Train `model` on MNIST with Adam and return `(tr_acc, te_acc)`, the training and
test accuracy (in percent) measured after the last epoch.

# Keywords
- `rng`: random number generator used to initialise the parameters/states and
  to draw the warm-up input.
- `nepochs`: number of passes over the training data.
- `batchsize`: minibatch size forwarded to `loadmnist`.
- `train_split`: fraction of the data used for training, forwarded to
  `loadmnist`; the remainder is used as the test set.
- `learning_rate`: step size handed to `Adam`.
- `kwargs...`: accepted for backward compatibility and ignored.

One line with the epoch time and both accuracies is printed per epoch. With
`nepochs = 0` no training happens and `(0.0, 0.0)` is returned.
"""
function train(model; rng=Xoshiro(0), nepochs=10, batchsize=128, train_split=0.9,
        learning_rate=3.0f-4, kwargs...)
    train_dataloader, test_dataloader = loadmnist(batchsize, train_split)
    ps, st = Lux.setup(rng, model)

    train_state = Training.TrainState(model, ps, st, Adam(learning_rate))

    ### Warmup the model
    # Run one gradient computation on a single dummy sample so that compilation
    # happens here and is not counted in the first epoch's timing.
    x_proto = randn(rng, Float32, 28, 28, 1, 1)
    y_proto = onehotbatch([1], 0:9)
    Training.compute_gradients(AutoZygote(), loss, (x_proto, y_proto), train_state)

    ### Lets train the model
    tr_acc, te_acc = 0.0, 0.0
    for epoch in 1:nepochs
        stime = time()
        for (x, y) in train_dataloader
            # Only the updated train state is needed; the other returned values
            # are discarded.
            _, _, _, train_state = Training.single_train_step!(
                AutoZygote(), loss, (x, y), train_state)
        end
        # Time covers the optimisation steps only, not the accuracy evaluation.
        ttime = time() - stime

        tr_acc = accuracy(
            model, train_state.parameters, train_state.states, train_dataloader) * 100
        te_acc = accuracy(
            model, train_state.parameters, train_state.states, test_dataloader) * 100

        @printf "[%2d/%2d] \t Time %.2fs \t Training Accuracy: %.2f%% \t Test Accuracy: \
                 %.2f%%\n" epoch nepochs ttime tr_acc te_acc
    end

    return tr_acc, te_acc
end

train (generic function with 1 method)

In [11]:
tr_acc, te_acc = train(lux_model)

[ 1/10] 	 Time 8.30s 	 Training Accuracy: 92.06% 	 Test Accuracy: 93.95%
[ 2/10] 	 Time 5.44s 	 Training Accuracy: 95.21% 	 Test Accuracy: 95.67%
[ 3/10] 	 Time 5.59s 	 Training Accuracy: 96.50% 	 Test Accuracy: 96.68%
[ 4/10] 	 Time 5.62s 	 Training Accuracy: 97.37% 	 Test Accuracy: 97.38%
[ 5/10] 	 Time 5.55s 	 Training Accuracy: 97.82% 	 Test Accuracy: 97.48%
[ 6/10] 	 Time 5.68s 	 Training Accuracy: 97.91% 	 Test Accuracy: 97.57%
[ 7/10] 	 Time 5.56s 	 Training Accuracy: 98.01% 	 Test Accuracy: 97.63%
[ 8/10] 	 Time 5.50s 	 Training Accuracy: 98.33% 	 Test Accuracy: 97.88%
[ 9/10] 	 Time 5.58s 	 Training Accuracy: 98.62% 	 Test Accuracy: 98.13%
[10/10] 	 Time 5.52s 	 Training Accuracy: 98.64% 	 Test Accuracy: 97.95%


(98.63518518518518, 97.95)

In [52]:
tr_acc2, te_acc2 = train(lux_model2)

[ 1/10] 	 Time 74.42s 	 Training Accuracy: 97.84% 	 Test Accuracy: 98.37%
[ 2/10] 	 Time 71.04s 	 Training Accuracy: 98.69% 	 Test Accuracy: 98.72%
[ 3/10] 	 Time 71.29s 	 Training Accuracy: 99.09% 	 Test Accuracy: 98.90%
[ 4/10] 	 Time 74.21s 	 Training Accuracy: 99.35% 	 Test Accuracy: 99.00%
[ 5/10] 	 Time 71.32s 	 Training Accuracy: 99.50% 	 Test Accuracy: 99.05%
[ 6/10] 	 Time 72.00s 	 Training Accuracy: 99.52% 	 Test Accuracy: 99.13%
[ 7/10] 	 Time 72.08s 	 Training Accuracy: 99.48% 	 Test Accuracy: 98.85%
[ 8/10] 	 Time 74.15s 	 Training Accuracy: 99.71% 	 Test Accuracy: 99.10%
[ 9/10] 	 Time 73.69s 	 Training Accuracy: 99.81% 	 Test Accuracy: 99.12%
[10/10] 	 Time 66.82s 	 Training Accuracy: 99.72% 	 Test Accuracy: 99.12%


(99.71666666666667, 99.11666666666666)

In [34]:
"""
    f(x, y)

Return the linear combination `2x + 3y`.
"""
function f(x, y)
    doubled = 2x
    tripled = 3y
    return doubled + tripled
end

f (generic function with 1 method)

In [36]:
# Print the LLVM IR that Julia generates for `f` when called with two integer
# arguments, to inspect how the arithmetic is compiled.
@code_llvm f(2,3)

[90m;  @ In[34]:1 within `f`[39m
[95mdefine[39m [36mi64[39m [93m@julia_f_23589[39m[33m([39m[36mi64[39m [95msignext[39m [0m%0[0m, [36mi64[39m [95msignext[39m [0m%1[33m)[39m [0m#0 [33m{[39m
[91mtop:[39m
[90m; ┌ @ int.jl:88 within `*`[39m
   [0m%2 [0m= [96m[1mshl[22m[39m [36mi64[39m [0m%0[0m, [33m1[39m
   [0m%3 [0m= [96m[1mmul[22m[39m [36mi64[39m [0m%1[0m, [33m3[39m
[90m; └[39m
[90m; ┌ @ int.jl:87 within `+`[39m
   [0m%4 [0m= [96m[1madd[22m[39m [36mi64[39m [0m%3[0m, [0m%2
   [96m[1mret[22m[39m [36mi64[39m [0m%4
[90m; └[39m
[33m}[39m
