In [1]:
# Activate the project environment in the current working directory so that the
# packages used below resolve against its Project.toml / Manifest.toml.
import Pkg
Pkg.activate(".")

[32m[1m  Activating[22m[39m project at `~/Documents/programming/SigmaCampNext-2025`


In [2]:
# Add every package this notebook uses to the active project: the Lux training stack
# (Lux, ADTypes, Optimisers, Zygote), data utilities (MLUtils, OneHotArrays, MLDatasets)
# and the standard libraries used for RNGs, statistics and formatted printing.
# NOTE(review): re-running this cell re-resolves package versions every time; once the
# Manifest exists, `Pkg.instantiate()` may be preferable for a reproducible environment.
Pkg.add(["Lux", "ADTypes", "MLUtils", "Optimisers", "Zygote", "OneHotArrays", "Random", "Statistics", "Printf", "MLDatasets"])

[32m[1m   Resolving[22m[39m package versions...
[32m[1m    Updating[22m[39m `~/Documents/programming/SigmaCampNext-2025/Project.toml`
  [90m[47edcb42] [39m[92m+ ADTypes v1.16.0[39m
  [90m[b2108857] [39m[92m+ Lux v1.16.0[39m
  [90m[f1d291b0] [39m[92m+ MLUtils v0.4.8[39m
  [90m[0b1bfda6] [39m[92m+ OneHotArrays v0.2.10[39m
  [90m[3bd65402] [39m[92m+ Optimisers v0.4.6[39m
  [90m[10745b16] [39m[92m+ Statistics v1.11.1[39m
  [90m[e88e6eb3] [39m[92m+ Zygote v0.7.10[39m
  [90m[de0858da] [39m[92m+ Printf v1.11.0[39m
[32m[1m    Updating[22m[39m `~/Documents/programming/SigmaCampNext-2025/Manifest.toml`
  [90m[47edcb42] [39m[92m+ ADTypes v1.16.0[39m
  [90m[082447d4] [39m[92m+ ChainRules v1.72.5[39m
  [90m[bbf7d656] [39m[92m+ CommonSubexpressions v0.3.1[39m
  [90m[2569d6c7] [39m[92m+ ConcreteStructs v0.2.3[39m
  [90m[163ba53b] [39m[92m+ DiffResults v1.1.0[39m
  [90m[b552c78f] [39m[92m+ DiffRules v1.15.1[39m
  [90m[8d63f2c5] [39m

In [3]:
using Lux, ADTypes, MLUtils, Optimisers, Zygote, OneHotArrays, Random, Statistics, Printf
using MLDatasets: MNIST

In [4]:
"""
    loadmnist(batchsize, train_split; rng=Random.default_rng())

Load the MNIST training split and return `(train_loader, heldout_loader)`.

The first 60 000 images of the MNIST training split are reshaped to a 4-D `Float32`
array with a singleton channel dimension (`size(imgs, 1) × size(imgs, 2) × 1 × N`), the
labels are one-hot encoded over the classes `0:9`, and the observations are split *in
order* with `splitobs(...; at=train_split)`.

Note that the held-out loader is carved out of MNIST's *training* split, so it is a
validation set rather than the official MNIST test set.

# Arguments
- `batchsize`: minibatch size used by both loaders.
- `train_split`: passed to `MLUtils.splitobs` as `at`, e.g. `0.9` keeps 90% for training.

# Keywords
- `rng`: random number generator used by the training `DataLoader` to shuffle the
  observations each epoch. Pass a seeded RNG (e.g. `Xoshiro(0)`) for reproducible
  batch order; the default matches the previous behaviour (global RNG).
"""
function loadmnist(batchsize, train_split; rng::AbstractRNG=Random.default_rng())
    # Load MNIST
    N = 60000
    dataset = MNIST(; split=:train)
    imgs = dataset.features[:, :, 1:N]
    labels_raw = dataset.targets[1:N]

    # Insert a singleton channel dimension so batches have layout (dim1, dim2, C=1, N).
    x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
    y_data = onehotbatch(labels_raw, 0:9)
    (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split)

    return (
        # Minibatch and reshuffle the training data every epoch using the caller's RNG.
        DataLoader(collect.((x_train, y_train)); batchsize, shuffle=true, rng),
        # Keep the held-out data in a fixed order.
        DataLoader(collect.((x_test, y_test)); batchsize, shuffle=false),
    )
end

loadmnist (generic function with 1 method)

In [5]:
# LeNet-5-style CNN for 28×28 single-channel images. The convolutions use Lux's defaults
# (stride 1, no padding) and each 2×2 max-pool halves the spatial size:
# 28 → 24 → 12 → 8 → 4, so flattening yields 4 · 4 · 16 = 256 features.
lux_model = Chain(
    Conv((5, 5), 1 => 6, relu),
    MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu),
    MaxPool((2, 2)),
    FlattenLayer(3),
    # Fully connected head ending in 10 raw class logits (no softmax; see the loss).
    Chain(
        Dense(256 => 128, relu),
        Dense(128 => 84, relu),
        Dense(84 => 10),
    ),
)

Chain(
    layer_1 = Conv((5, 5), 1 => 6, relu),  [90m# 156 parameters[39m
    layer_2 = MaxPool((2, 2)),
    layer_3 = Conv((5, 5), 6 => 16, relu),  [90m# 2_416 parameters[39m
    layer_4 = MaxPool((2, 2)),
    layer_5 = FlattenLayer{Static.StaticInt{3}}(static(3)),
    layer_6 = Chain(
        layer_1 = Dense(256 => 128, relu),  [90m# 32_896 parameters[39m
        layer_2 = Dense(128 => 84, relu),  [90m# 10_836 parameters[39m
        layer_3 = Dense(84 => 10),      [90m# 850 parameters[39m
    ),
) [90m        # Total: [39m47_154 parameters,
[90m          #        plus [39m0 states.

In [6]:
# Wider CNN: two stacked 3×3 convolutions (28 → 26 → 24) and a 3×3 max-pool (24 → 8)
# give 8 · 8 · 64 = 4096 features after flattening. The Dropout layer is stateful
# (hence the model's non-empty state) and is disabled when evaluated via `Lux.testmode`.
lux_model2 = Chain(
    Conv((3, 3), 1 => 32, relu),
    Conv((3, 3), 32 => 64, relu),
    MaxPool((3, 3)),
    Dropout(0.5),
    FlattenLayer(3),
    # Classifier head ending in 10 raw class logits.
    Chain(
        Dense(4096 => 250, relu),
        Dense(250 => 10),
    ),
)

Chain(
    layer_1 = Conv((3, 3), 1 => 32, relu),  [90m# 320 parameters[39m
    layer_2 = Conv((3, 3), 32 => 64, relu),  [90m# 18_496 parameters[39m
    layer_3 = MaxPool((3, 3)),
    layer_4 = Dropout(0.5),
    layer_5 = FlattenLayer{Static.StaticInt{3}}(static(3)),
    layer_6 = Chain(
        layer_1 = Dense(4096 => 250, relu),  [90m# 1_024_250 parameters[39m
        layer_2 = Dense(250 => 10),     [90m# 2_510 parameters[39m
    ),
) [90m        # Total: [39m1_045_576 parameters,
[90m          #        plus [39m2 states.

In [7]:
# Cross-entropy evaluated directly on raw logits (`logits=Val(true)`), which is why the
# models end in a plain `Dense` layer without a softmax.
const loss = CrossEntropyLoss(logits=Val(true))

"""
    accuracy(model, ps, st, dataloader)

Return the fraction of samples in `dataloader` whose predicted class (the `onecold` of
the model output) equals the `onecold` of the one-hot target.

The model is run with `Lux.testmode(st)`, so training-only behaviour such as `Dropout`
is switched off. The result lies in `[0, 1]`; it is `NaN` (0/0) when `dataloader`
yields no samples.
"""
function accuracy(model, ps, st, dataloader)
    eval_state = Lux.testmode(st)
    ncorrect = 0
    nseen = 0
    for batch in dataloader
        x, y = batch
        logits, _ = model(x, ps, eval_state)
        # Materialise as a plain `Array` before decoding (presumably so device arrays
        # are supported too).
        predicted = onecold(Array(logits))
        expected = onecold(y)
        ncorrect += count(predicted .== expected)
        nseen += length(expected)
    end
    return ncorrect / nseen
end

accuracy (generic function with 1 method)

In [8]:
"""
    train(model; rng=Xoshiro(0), nepochs=10, batchsize=128, train_split=0.9,
          learning_rate=3.0f-4, kwargs...)

Train `model` on MNIST with Adam and the cross-entropy `loss`, printing the epoch time
and accuracies after every epoch. Return `(train_accuracy, heldout_accuracy)` in percent,
as measured after the final epoch.

# Arguments
- `model`: a Lux model mapping `28×28×1×B` `Float32` images to 10 logits per sample.

# Keywords
- `rng`: RNG used by `Lux.setup` to initialise parameters and states.
- `nepochs`: number of passes over the training data; must be at least 1.
- `batchsize`: minibatch size forwarded to `loadmnist`.
- `train_split`: split point forwarded to `loadmnist`; the remainder of the MNIST training
  split is the held-out set printed as "Test Accuracy".
- `learning_rate`: Adam step size.
- `kwargs...`: accepted only for backward compatibility; they are ignored and a warning
  is logged if any are given (previously they were silently dropped).

# Throws
- `ArgumentError` if `nepochs < 1`.
"""
function train(model; rng=Xoshiro(0), nepochs::Integer=10, batchsize::Integer=128,
        train_split=0.9, learning_rate=3.0f-4, kwargs...)
    nepochs >= 1 || throw(ArgumentError("nepochs must be at least 1, got $nepochs"))
    isempty(kwargs) ||
        @warn "train: ignoring unsupported keyword arguments" ignored = Tuple(keys(kwargs))

    train_dataloader, test_dataloader = loadmnist(batchsize, train_split)
    ps, st = Lux.setup(rng, model)
    train_state = Training.TrainState(model, ps, st, Adam(learning_rate))
    backend = AutoZygote()

    ### Warm up (compile) the gradient path on a real minibatch so the argument types
    ### match those seen in the training loop. `compute_gradients` does not update the
    ### parameters; the returned state is kept, as the Lux API expects.
    x_warm, y_warm = first(train_dataloader)
    _, _, _, train_state = Training.compute_gradients(
        backend, loss, (x_warm, y_warm), train_state)

    ### Train the model
    tr_acc, te_acc = 0.0, 0.0
    for epoch in 1:nepochs
        stime = time()
        for (x, y) in train_dataloader
            # The optimiser update happens inside `single_train_step!`; only the new
            # train state is needed here.
            _, _, _, train_state = Training.single_train_step!(
                backend, loss, (x, y), train_state)
        end
        ttime = time() - stime

        tr_acc = accuracy(
            model, train_state.parameters, train_state.states, train_dataloader) * 100
        te_acc = accuracy(
            model, train_state.parameters, train_state.states, test_dataloader) * 100

        @printf "[%2d/%2d] \t Time %.2fs \t Training Accuracy: %.2f%% \t Test Accuracy: \
                 %.2f%%\n" epoch nepochs ttime tr_acc te_acc
    end

    return tr_acc, te_acc
end

train (generic function with 1 method)

In [9]:
# Train the LeNet-style model; `train` returns the final (training, held-out) accuracy
# in percent.
tr_acc, te_acc = train(lux_model)

This program has requested access to the data dependency MNIST.
which is not currently installed. It can be installed automatically, and you will not see this message again.

Dataset: THE MNIST DATABASE of handwritten digits
Authors: Yann LeCun, Corinna Cortes, Christopher J.C. Burges
Website: http://yann.lecun.com/exdb/mnist/

[LeCun et al., 1998a]
    Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner.
    "Gradient-based learning applied to document recognition."
    Proceedings of the IEEE, 86(11):2278-2324, November 1998

The files are available for download at the offical
website linked above. Note that using the data
responsibly and respecting copyright remains your
responsibility. The authors of MNIST aren't really
explicit about any terms of use, so please read the
website to make sure you want to download the
dataset.



Do you want to download the dataset from ["https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz", "https://ossci-datasets.s3.amazonaws.com/mn

stdin>  y


[ 1/10] 	 Time 7.04s 	 Training Accuracy: 91.63% 	 Test Accuracy: 93.53%
[ 2/10] 	 Time 5.11s 	 Training Accuracy: 94.98% 	 Test Accuracy: 95.65%
[ 3/10] 	 Time 5.13s 	 Training Accuracy: 96.50% 	 Test Accuracy: 96.93%
[ 4/10] 	 Time 5.12s 	 Training Accuracy: 97.34% 	 Test Accuracy: 97.38%
[ 5/10] 	 Time 5.09s 	 Training Accuracy: 97.40% 	 Test Accuracy: 97.45%
[ 6/10] 	 Time 5.06s 	 Training Accuracy: 98.01% 	 Test Accuracy: 97.72%
[ 7/10] 	 Time 5.06s 	 Training Accuracy: 98.24% 	 Test Accuracy: 97.78%
[ 8/10] 	 Time 5.05s 	 Training Accuracy: 98.56% 	 Test Accuracy: 98.13%
[ 9/10] 	 Time 5.06s 	 Training Accuracy: 98.54% 	 Test Accuracy: 98.08%
[10/10] 	 Time 5.13s 	 Training Accuracy: 98.61% 	 Test Accuracy: 98.00%


(98.61111111111111, 98.0)

In [10]:
# Train the wider dropout CNN for comparison; returns the final (training, held-out)
# accuracy in percent.
tr_acc2, te_acc2 = train(lux_model2)

[ 1/10] 	 Time 70.71s 	 Training Accuracy: 97.97% 	 Test Accuracy: 98.33%
[ 2/10] 	 Time 69.48s 	 Training Accuracy: 98.70% 	 Test Accuracy: 98.68%
[ 3/10] 	 Time 69.18s 	 Training Accuracy: 99.11% 	 Test Accuracy: 98.90%
[ 4/10] 	 Time 68.65s 	 Training Accuracy: 99.23% 	 Test Accuracy: 98.85%
[ 5/10] 	 Time 70.56s 	 Training Accuracy: 99.55% 	 Test Accuracy: 99.08%
[ 6/10] 	 Time 79.19s 	 Training Accuracy: 99.54% 	 Test Accuracy: 99.02%
[ 7/10] 	 Time 69.50s 	 Training Accuracy: 99.68% 	 Test Accuracy: 99.15%
[ 8/10] 	 Time 69.14s 	 Training Accuracy: 99.73% 	 Test Accuracy: 99.22%
[ 9/10] 	 Time 73.28s 	 Training Accuracy: 99.72% 	 Test Accuracy: 99.18%
[10/10] 	 Time 69.17s 	 Training Accuracy: 99.85% 	 Test Accuracy: 99.28%


(99.8537037037037, 99.28333333333333)

In [34]:
# Linear combination 2x + 3y, used to inspect the LLVM IR Julia generates.
f(x, y) = 2x + 3y

f (generic function with 1 method)

In [36]:
# Print the LLVM IR for `f` specialised on two `Int` (i64) arguments; note that the
# multiplication by 2 is lowered to a left shift.
@code_llvm f(2,3)

[90m;  @ In[34]:1 within `f`[39m
[95mdefine[39m [36mi64[39m [93m@julia_f_23589[39m[33m([39m[36mi64[39m [95msignext[39m [0m%0[0m, [36mi64[39m [95msignext[39m [0m%1[33m)[39m [0m#0 [33m{[39m
[91mtop:[39m
[90m; ┌ @ int.jl:88 within `*`[39m
   [0m%2 [0m= [96m[1mshl[22m[39m [36mi64[39m [0m%0[0m, [33m1[39m
   [0m%3 [0m= [96m[1mmul[22m[39m [36mi64[39m [0m%1[0m, [33m3[39m
[90m; └[39m
[90m; ┌ @ int.jl:87 within `+`[39m
   [0m%4 [0m= [96m[1madd[22m[39m [36mi64[39m [0m%3[0m, [0m%2
   [96m[1mret[22m[39m [36mi64[39m [0m%4
[90m; └[39m
[33m}[39m
