In [1]:
import Pkg
Pkg.activate(".")

[32m[1m  Activating[22m[39m project at `~/Documents/programming/BME-574-2024`


In [9]:
Pkg.add(["Lux", "ADTypes", "MLUtils", "Optimisers", "Zygote", "OneHotArrays", "Random", "Statistics", "Printf", "MLDatasets","Enzyme"])

[32m[1m   Resolving[22m[39m package versions...
[32m[1m   Installed[22m[39m OptimizationOptimisers ───── v0.3.4
[32m[1m   Installed[22m[39m Enzyme_jll ───────────────── v0.0.159+0
[32m[1m   Installed[22m[39m GLMakie ──────────────────── v0.10.15
[32m[1m   Installed[22m[39m OptimizationPolyalgorithms ─ v0.3.0
[32m[1m   Installed[22m[39m Plots ────────────────────── v1.40.7
[32m[1m   Installed[22m[39m Enzyme ───────────────────── v0.13.13
[32m[1m   Installed[22m[39m GPUCompiler ──────────────── v1.0.1
[32m[1m    Updating[22m[39m `~/Documents/programming/BME-574-2024/Project.toml`
  [90m[13f3f980] [39m[93m↑ CairoMakie v0.12.13 ⇒ v0.12.15[39m
  [90m[717857b8] [39m[93m↑ DSP v0.7.9 ⇒ v0.7.10[39m
  [90m[7da242da] [39m[92m+ Enzyme v0.13.13[39m
  [90m[f6369f11] [39m[93m↑ ForwardDiff v0.10.36 ⇒ v0.10.37[39m
  [90m[e9467ef8] [39m[93m↑ GLMakie v0.10.13 ⇒ v0.10.15[39m
  [90m[7f7a1694] [39m[93m↑ Optimization v3.26.3 ⇒ v4.0.5[39m
  [90m[363

In [2]:
using Lux, ADTypes, MLUtils, Optimisers, Zygote, OneHotArrays, Random, Statistics, Printf, Enzyme
using MLDatasets: MNIST

[32m[1mPrecompiling[22m[39m LuxMLUtilsExt
[32m  ✓ [39m[90mWeightInitializers → WeightInitializersGPUArraysExt[39m
[32m  ✓ [39m[90mLux → LuxMLUtilsExt[39m
[32m  ✓ [39m[90mLux → LuxZygoteExt[39m
  3 dependencies successfully precompiled in 3 seconds. 241 already precompiled.
[32m[1mPrecompiling[22m[39m MLDatasets
[32m  ✓ [39m[90mDataDeps[39m
[32m  ✓ [39m[90mCSV[39m
[32m  ✓ [39mMLDatasets
  3 dependencies successfully precompiled in 9 seconds. 200 already precompiled.


In [49]:
"""
    loadmnist(batchsize, train_split)

Build a `(train_loader, test_loader)` pair of `DataLoader`s from the first 60000
samples of the MNIST `:train` split.

Images are reshaped to `(H, W, 1, N)` and converted to `Float32`; labels are one-hot
encoded over the classes `0:9`. `train_split` is forwarded to `splitobs` as `at`.
Only the training loader shuffles its observations.
"""
function loadmnist(batchsize, train_split)
    nsamples = 60000
    mnist = MNIST(; split=:train)
    pixels = mnist.features[:, :, 1:nsamples]
    labels = mnist.targets[1:nsamples]

    # Insert a singleton channel axis so every image is (H, W, C) with C == 1
    height, width, nimgs = size(pixels, 1), size(pixels, 2), size(pixels, 3)
    inputs = Float32.(reshape(pixels, height, width, 1, nimgs))
    targets = onehotbatch(labels, 0:9)

    (xtr, ytr), (xte, yte) = splitobs((inputs, targets); at=train_split)

    # `collect` materialises the split views into plain arrays before batching
    train_loader = DataLoader(collect.((xtr, ytr)); batchsize, shuffle=true)
    # Evaluation order is kept fixed: no shuffling for the held-out part
    test_loader = DataLoader(collect.((xte, yte)); batchsize, shuffle=false)
    return (train_loader, test_loader)
end

loadmnist (generic function with 1 method)

In [41]:
# LeNet-style CNN. For a 28×28×1 input the two valid 5×5 convolutions and 2×2 poolings
# give 28 → 24 → 12 → 8 → 4, so the flattened feature vector has 4 * 4 * 16 = 256 entries.
lux_model = Chain(
    Conv((5, 5), 1 => 6, relu),
    MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu),
    MaxPool((2, 2)),
    FlattenLayer(3),
    # Classifier head; the last layer has no activation, so it emits 10 raw scores
    Chain(
        Dense(256 => 128, relu),
        Dense(128 => 84, relu),
        Dense(84 => 10),
    ),
)

Chain(
    layer_1 = Conv((5, 5), 1 => 6, relu),  [90m# 156 parameters[39m
    layer_2 = MaxPool((2, 2)),
    layer_3 = Conv((5, 5), 6 => 16, relu),  [90m# 2_416 parameters[39m
    layer_4 = MaxPool((2, 2)),
    layer_5 = FlattenLayer{Static.StaticInt{3}}(static(3)),
    layer_6 = Chain(
        layer_1 = Dense(256 => 128, relu),  [90m# 32_896 parameters[39m
        layer_2 = Dense(128 => 84, relu),  [90m# 10_836 parameters[39m
        layer_3 = Dense(84 => 10),      [90m# 850 parameters[39m
    ),
) [90m        # Total: [39m47_154 parameters,
[90m          #        plus [39m0 states.

In [32]:
# Wider CNN with dropout. For a 28×28×1 input the two valid 3×3 convolutions and the
# 3×3 pooling give 28 → 26 → 24 → 8, so the flattened size is 8 * 8 * 64 = 4096.
lux_model2 = Chain(
    Conv((3, 3), 1 => 32, relu),
    Conv((3, 3), 32 => 64, relu),
    MaxPool((3, 3)),
    Dropout(0.5),
    FlattenLayer(3),
    # Classifier head; the last layer has no activation, so it emits 10 raw scores
    Chain(
        Dense(4096 => 250, relu),
        Dense(250 => 10),
    ),
)

Chain(
    layer_1 = Conv((3, 3), 1 => 32, relu),  [90m# 320 parameters[39m
    layer_2 = Conv((3, 3), 32 => 64, relu),  [90m# 18_496 parameters[39m
    layer_3 = MaxPool((3, 3)),
    layer_4 = Dropout(0.5),
    layer_5 = FlattenLayer{Static.StaticInt{3}}(static(3)),
    layer_6 = Chain(
        layer_1 = Dense(4096 => 250, relu),  [90m# 1_024_250 parameters[39m
        layer_2 = Dense(250 => 10),     [90m# 2_510 parameters[39m
    ),
) [90m        # Total: [39m1_045_576 parameters,
[90m          #        plus [39m2 states.

In [5]:
# Cross-entropy loss configured with `logits=Val(true)`; the models in this notebook end
# in a plain `Dense` layer with no softmax, i.e. they output unnormalised scores.
const loss = CrossEntropyLoss(; logits=Val(true))

"""
    accuracy(model, ps, st, dataloader)

Return the fraction (between 0 and 1) of observations in `dataloader` whose predicted
class matches the target class.

`st` is switched to test mode via `Lux.testmode` before the model is evaluated with
parameters `ps`. Each batch `(x, y)` is expected to carry one-hot targets `y`; both the
targets and the model output are decoded with `onecold`.
"""
function accuracy(model, ps, st, dataloader)
    eval_state = Lux.testmode(st)
    n_hits = 0
    n_seen = 0
    for (inputs, targets) in dataloader
        # The model call returns a tuple; its first element holds the class scores
        scores = first(model(inputs, ps, eval_state))
        truth = onecold(targets)
        guess = onecold(Array(scores))
        n_hits += count(truth .== guess)
        n_seen += length(truth)
    end
    return n_hits / n_seen
end

accuracy (generic function with 1 method)

In [51]:
"""
    train(model; rng=Xoshiro(0), nepochs=10, batchsize=128, train_split=0.9,
          learning_rate=3.0f-4, backend=AutoEnzyme(), kwargs...)

Train `model` on MNIST with Adam and return `(train_accuracy, test_accuracy)`, both in
percent, as measured after the last epoch.

# Keywords
- `rng`: random number generator used to set up the parameters/states and to draw the
  warm-up input.
- `nepochs`: number of passes over the training loader (must be non-negative).
- `batchsize`: minibatch size forwarded to `loadmnist`.
- `train_split`: fraction of the data used for training, forwarded to `loadmnist`.
- `learning_rate`: step size handed to `Adam`.
- `backend`: ADTypes backend used to compute gradients.
- `kwargs...`: accepted for backward compatibility only; anything left here is unused
  and triggers a warning instead of being dropped silently.

# Throws
- `ArgumentError` if `nepochs` is negative.

Prints one progress line per epoch. With `nepochs == 0` no training happens and
`(0.0, 0.0)` is returned.
"""
function train(model; rng=Xoshiro(0), nepochs=10, batchsize=128, train_split=0.9,
        learning_rate=3.0f-4, backend=AutoEnzyme(), kwargs...)
    nepochs >= 0 || throw(ArgumentError("nepochs must be non-negative, got $nepochs"))
    if !isempty(kwargs)
        @warn "train: ignoring unsupported keyword arguments" unused = collect(keys(kwargs))
    end

    train_dataloader, test_dataloader = loadmnist(batchsize, train_split)
    ps, st = Lux.setup(rng, model)

    train_state = Training.TrainState(model, ps, st, Adam(learning_rate))

    ### Warmup the model
    # One gradient evaluation on a single dummy 28×28×1 sample, done before the timed
    # epochs so that first-call compilation is not counted in the epoch timings.
    x_proto = randn(rng, Float32, 28, 28, 1, 1)
    y_proto = onehotbatch([1], 0:9)
    Training.compute_gradients(backend, loss, (x_proto, y_proto), train_state)

    ### Lets train the model
    tr_acc, te_acc = 0.0, 0.0
    for epoch in 1:nepochs
        stime = time()
        for (x, y) in train_dataloader
            # Only the updated training state is needed from the returned 4-tuple
            _, _, _, train_state = Training.single_train_step!(
                backend, loss, (x, y), train_state)
        end
        # Time spent in the optimisation steps only; the accuracy passes are excluded
        ttime = time() - stime

        tr_acc = accuracy(
            model, train_state.parameters, train_state.states, train_dataloader) * 100
        te_acc = accuracy(
            model, train_state.parameters, train_state.states, test_dataloader) * 100

        @printf "[%2d/%2d] \t Time %.2fs \t Training Accuracy: %.2f%% \t Test Accuracy: \
                 %.2f%%\n" epoch nepochs ttime tr_acc te_acc
    end

    return tr_acc, te_acc
end

train (generic function with 1 method)

In [47]:
# Train the first model; `train` returns the final (training, test) accuracy in percent.
tr_acc, te_acc = train(lux_model)

[ 1/50] 	 Time 0.99s 	 Training Accuracy: 23.28% 	 Test Accuracy: 18.00%
[ 2/50] 	 Time 0.18s 	 Training Accuracy: 47.17% 	 Test Accuracy: 43.50%
[ 3/50] 	 Time 0.18s 	 Training Accuracy: 60.28% 	 Test Accuracy: 59.00%
[ 4/50] 	 Time 0.19s 	 Training Accuracy: 69.50% 	 Test Accuracy: 65.00%
[ 5/50] 	 Time 0.17s 	 Training Accuracy: 74.89% 	 Test Accuracy: 73.00%
[ 6/50] 	 Time 0.17s 	 Training Accuracy: 77.78% 	 Test Accuracy: 76.50%
[ 7/50] 	 Time 0.17s 	 Training Accuracy: 80.89% 	 Test Accuracy: 78.00%
[ 8/50] 	 Time 0.17s 	 Training Accuracy: 83.39% 	 Test Accuracy: 82.50%
[ 9/50] 	 Time 0.27s 	 Training Accuracy: 84.89% 	 Test Accuracy: 84.00%
[10/50] 	 Time 0.18s 	 Training Accuracy: 86.67% 	 Test Accuracy: 85.50%
[11/50] 	 Time 0.17s 	 Training Accuracy: 87.94% 	 Test Accuracy: 86.00%
[12/50] 	 Time 0.18s 	 Training Accuracy: 88.67% 	 Test Accuracy: 88.00%
[13/50] 	 Time 0.18s 	 Training Accuracy: 90.06% 	 Test Accuracy: 88.00%
[14/50] 	 Time 0.18s 	 Training Accuracy: 90.17% 	 

(99.72222222222223, 93.5)

In [52]:
# Train the second model; `train` returns the final (training, test) accuracy in percent.
tr_acc2, te_acc2 = train(lux_model2)

[ 1/10] 	 Time 74.42s 	 Training Accuracy: 97.84% 	 Test Accuracy: 98.37%
[ 2/10] 	 Time 71.04s 	 Training Accuracy: 98.69% 	 Test Accuracy: 98.72%
[ 3/10] 	 Time 71.29s 	 Training Accuracy: 99.09% 	 Test Accuracy: 98.90%
[ 4/10] 	 Time 74.21s 	 Training Accuracy: 99.35% 	 Test Accuracy: 99.00%
[ 5/10] 	 Time 71.32s 	 Training Accuracy: 99.50% 	 Test Accuracy: 99.05%
[ 6/10] 	 Time 72.00s 	 Training Accuracy: 99.52% 	 Test Accuracy: 99.13%
[ 7/10] 	 Time 72.08s 	 Training Accuracy: 99.48% 	 Test Accuracy: 98.85%
[ 8/10] 	 Time 74.15s 	 Training Accuracy: 99.71% 	 Test Accuracy: 99.10%
[ 9/10] 	 Time 73.69s 	 Training Accuracy: 99.81% 	 Test Accuracy: 99.12%
[10/10] 	 Time 66.82s 	 Training Accuracy: 99.72% 	 Test Accuracy: 99.12%


(99.71666666666667, 99.11666666666666)

In [34]:
"""
    f(x, y)

Return the linear combination `2x + 3y`.
"""
f(x, y) = 2x + 3y

f (generic function with 1 method)

In [36]:
# Show the LLVM IR generated for `f` when it is called with two integer arguments.
@code_llvm f(2,3)

[90m;  @ In[34]:1 within `f`[39m
[95mdefine[39m [36mi64[39m [93m@julia_f_23589[39m[33m([39m[36mi64[39m [95msignext[39m [0m%0[0m, [36mi64[39m [95msignext[39m [0m%1[33m)[39m [0m#0 [33m{[39m
[91mtop:[39m
[90m; ┌ @ int.jl:88 within `*`[39m
   [0m%2 [0m= [96m[1mshl[22m[39m [36mi64[39m [0m%0[0m, [33m1[39m
   [0m%3 [0m= [96m[1mmul[22m[39m [36mi64[39m [0m%1[0m, [33m3[39m
[90m; └[39m
[90m; ┌ @ int.jl:87 within `+`[39m
   [0m%4 [0m= [96m[1madd[22m[39m [36mi64[39m [0m%3[0m, [0m%2
   [96m[1mret[22m[39m [36mi64[39m [0m%4
[90m; └[39m
[33m}[39m
