In [1]:
using Pkg
# Print the packages (and their versions) of the active project environment,
# so the output below records what this notebook was run with.
Pkg.status()

[32m[1mStatus[22m[39m `~/git/Julia_ML_training/unit5/Project.toml`
[32m⌃[39m [90m[cbdf2221] [39mAlgebraOfGraphics v0.10.7
  [90m[336ed68f] [39mCSV v0.10.15
[32m⌃[39m [90m[13f3f980] [39mCairoMakie v0.13.9
  [90m[a93c6f00] [39mDataFrames v1.7.0
  [90m[31c24e10] [39mDistributions v0.25.120
[32m⌃[39m [90m[587475ba] [39mFlux v0.16.3
  [90m[38e38edf] [39mGLM v1.9.0
  [90m[60bf3e95] [39mGLPK v1.2.1
  [90m[09f84164] [39mHypothesisTests v0.11.5
  [90m[4076af6c] [39mJuMP v1.26.0
  [90m[23fbe1c1] [39mLatexify v0.16.8
[32m⌃[39m [90m[b2108857] [39mLux v1.13.0
  [90m[eb30cadb] [39mMLDatasets v0.7.18
  [90m[add582a8] [39mMLJ v0.20.8
  [90m[f1d291b0] [39mMLUtils v0.4.8
[33m⌅[39m [90m[ee78f7c6] [39mMakie v0.22.9
  [90m[ff71e718] [39mMixedModels v4.35.2
  [90m[6f286f6a] [39mMultivariateStats v0.10.3
  [90m[636a865e] [39mNearestNeighborModels v0.2.3
[32m⌃[39m [90m[429524aa] [39mOptim v1.12.0
  [90m[3bd65402] [39mOptimisers v0.4.6
  [90m[92933f4c]

In [2]:
using Lux
using Zygote # The primary automatic differentiation engine
using Optimisers # The new standard for optimizers

In [3]:
using MLDatasets
using MLUtils # Provides the DataLoader
using Statistics
using Random
using ProgressMeter
using Flux: onehotbatch, onecold, logitcrossentropy # Import the needed utilities from Flux

In [4]:
# Configure how ProgressMeter renders inside IJulia.
# NOTE(review): `:append` presumably makes progress updates get appended to the cell
# output instead of clearing it — confirm against the ProgressMeter documentation.
ProgressMeter.ijulia_behavior(:append);

In [5]:
# Use Julia's default RNG and seed it with 0. `rng` is later passed to `Lux.setup`
# when the model parameters and state are created.
const rng = Random.default_rng()
Random.seed!(rng, 0);

In [6]:
# Hyperparameters
learning_rate = 0.001 # passed to the Adam rule when the optimiser state is set up
batch_size = 128 # `batchsize` of the training DataLoader
num_epochs = 10; # number of passes over the training DataLoader

In [7]:
# Load the MNIST training and test splits. Indexing a split with `[:]` is
# destructured here into (images, labels).
# NOTE(review): the images are presumably 28×28×N arrays, since they are later
# reshaped to 28*28 rows — confirm against MLDatasets.
train_x_raw, train_y_raw = MNIST(split=:train)[:]
test_x_raw,  test_y_raw  = MNIST(split=:test)[:];

In [8]:
"""
    preprocess_features(x)

Flatten `x` into a matrix with `28*28` rows (the second dimension is inferred from
the number of elements) and convert every element to `Float32`.

Returns a newly allocated array; `x` itself is not modified.
"""
function preprocess_features(x)
    n_pixels = 28 * 28
    # `:` lets `reshape` work out the number of columns from `length(x)`.
    flattened = reshape(x, n_pixels, :)
    return Float32.(flattened)
end

# Flatten the raw images into Float32 feature matrices with 28*28 rows.
train_x = preprocess_features(train_x_raw)
test_x = preprocess_features(test_x_raw)
# One-hot encode the labels against the classes 0:9.
# Note: the evaluation code visible in this notebook compares predictions against
# `test_y_raw`, so `test_y` is not used there.
train_y = onehotbatch(train_y_raw, 0:9)
test_y = onehotbatch(test_y_raw, 0:9);

In [9]:
# Use DataLoader from MLUtils
# The training loop iterates this loader as `(x_batch, y_batch)` pairs;
# `batchsize=batch_size` sets the mini-batch size and `shuffle=true` requests
# shuffled sample order.
train_loader = DataLoader((train_x, train_y), batchsize=batch_size, shuffle=true);

In [10]:
# --- Model Definition and Initialization (Key Lux difference) ---

# Define the model structure. It looks similar to Flux's Chain.
# Note the `=>` syntax for Dense layers.
# Only `onehotbatch`, `onecold` and `logitcrossentropy` are imported from Flux in this
# notebook, so `Chain` and `Dense` below resolve through `using Lux`.
# The last layer has no activation: it outputs raw logits, which is what
# `logitcrossentropy` in the loss function is given.
model = Chain(
      Dense(28*28 => 128, relu),  # Input: 784 -> Hidden 1: 128 neurons
      Dense(128 => 64, relu),     # Hidden 1: 128 -> Hidden 2: 64 neurons
      Dense(64 => 10)              # Hidden 2: 64 -> Output: 10 logits
)

Chain(
    layer_1 = Dense(784 => 128, relu),  [90m# 100_480 parameters[39m
    layer_2 = Dense(128 => 64, relu),   [90m# 8_256 parameters[39m
    layer_3 = Dense(64 => 10),          [90m# 650 parameters[39m
) [90m        # Total: [39m109_386 parameters,
[90m          #        plus [39m0 states.

In [11]:
# In Lux, we explicitly initialize the parameters (ps) and state (st).
# This is a core concept: the model structure is separate from its learnable parts.
# `rng` is the RNG that was seeded with 0 above, which should make the initial
# parameters repeatable from run to run.
ps, st = Lux.setup(rng, model);

In [12]:
# You can inspect the structure of the parameters.
# `display` is a Base function, not part of Lux's API, so it is called directly
# instead of being reached through the `Lux` module namespace.
display(ps)

(layer_1 = (weight = Float32[-0.10921513 0.056196645 … 0.110608466 0.057007626; -0.011075623 -0.112844504 … -0.052739404 0.046212457; … ; 0.040891483 0.018354304 … 0.087309256 0.091325976; -0.011361991 -0.09112135 … -0.122030705 -0.11557871], bias = Float32[0.0125793945, -0.024275485, 0.03424715, 8.1130434f-5, -0.02638943, -0.034375455, -0.010637262, 0.025727345, 0.016990406, 0.011116858  …  -0.031849086, -0.021829436, 0.0001961163, 0.022090366, -0.025813958, -0.007818311, 0.018396122, -0.012626337, 0.01134666, 0.0288197]), layer_2 = (weight = Float32[0.2037279 0.121231 … 0.2081713 -0.26592368; 0.10215899 0.30275932 … -0.13877812 0.28471065; … ; -0.030344477 0.09217859 … 0.0432775 -0.19036081; -0.1642193 0.20556633 … 0.06377395 -0.094860405], bias = Float32[-0.06744722, -0.010483333, 0.08303247, -0.07705071, 0.077265754, -0.00868912, -0.04639269, 0.035928313, 0.0037759468, 0.027657595  …  0.01394165, -0.025066027, 0.06153464, 0.0838878, 0.07016804, 0.005167151, -0.020277672, -0.0531549

In [13]:
# --- Loss Function and Optimizer Setup ---

"""
    loss_function(model, ps, st, x, y)

Run `model` on the batch `x` with parameters `ps` and state `st`, and score the
resulting logits against the targets `y` with `logitcrossentropy`.

Returns the tuple `(loss, st_new)`, where `st_new` is the state returned by the
model call (potentially updated).
"""
function loss_function(model, ps, st, x, y)
    logits, updated_state = model(x, ps, st)
    return logitcrossentropy(logits, y), updated_state
end

loss_function (generic function with 1 method)

In [14]:
# Set up the optimizer using Optimisers.jl
# It takes the optimization rule (Adam) and the parameters to create the optimizer state.
opt_state = Optimisers.setup(Adam(learning_rate), ps)

# --- Training Loop (explicitly written) ---
# The loop lives in functions rather than at top level. A top-level `for` loop that
# reassigns the globals `ps`, `st` and `opt_state` only works under the notebook/REPL
# soft-scope rule and fails when the same code is run as a script. Inside a function
# the updated values are simply returned to the caller.

"""
    train_step(model, ps, st, opt_state, x_batch, y_batch)

Perform one optimisation step on a single mini-batch.

Returns `(loss, ps, st, opt_state)`: the batch loss, followed by the updated
parameters, model state and optimiser state.
"""
function train_step(model, ps, st, opt_state, x_batch, y_batch)
    # `loss_function` returns `(loss, st_new)`. `Zygote.withgradient` differentiates
    # the first element and hands the whole tuple back alongside the gradients.
    # The closure argument is named `p` and `st` is never reassigned in this function,
    # so the closure does not capture a reassigned (boxed) variable.
    (loss, st_new), grads = Zygote.withgradient(
        p -> loss_function(model, p, st, x_batch, y_batch), ps
    )

    # Update the optimizer state and the parameters
    opt_state, ps = Optimisers.update!(opt_state, ps, grads[1])

    return loss, ps, st_new, opt_state
end

"""
    train_model(model, ps, st, opt_state, train_loader, test_x, test_labels, num_epochs)

Train `model` for `num_epochs` passes over `train_loader`, printing the accuracy on
`test_x` / `test_labels` (integer labels in `0:9`) after every epoch.

Returns the updated `(ps, st, opt_state)`.
"""
function train_model(model, ps, st, opt_state, train_loader, test_x, test_labels, num_epochs)
    for epoch in 1:num_epochs
        p = Progress(length(train_loader); desc="Epoch $epoch: ")

        for (x_batch, y_batch) in train_loader
            loss, ps, st, opt_state = train_step(model, ps, st, opt_state, x_batch, y_batch)
            next!(p; showvalues=[(:loss, loss)]) # Update progress bar
        end

        # Evaluation step after each epoch: no gradients are needed, and the state is
        # switched to inference mode so that layers which behave differently during
        # training would be evaluated correctly. The model call still requires `ps`.
        y_hat_logits, _ = model(test_x, ps, Lux.testmode(st))
        # Passing the label set makes `onecold` return the 0-based digit directly.
        y_hat_labels = onecold(y_hat_logits, 0:9)

        current_accuracy = mean(y_hat_labels .== test_labels)
        println("\nEpoch $epoch: Test Accuracy = ", round(current_accuracy * 100, digits=2), "%")
    end

    return ps, st, opt_state
end

println("\nStarting training...")

ps, st, opt_state = train_model(
    model, ps, st, opt_state, train_loader, test_x, test_y_raw, num_epochs
)

println("\nTraining complete!")


Starting training...


[32mEpoch 1:   0%|▏                                         |  ETA: 0:27:44[39m
[A4m   loss: 2.4095874[39m
[32mEpoch 1: 100%|██████████████████████████████████████████| Time: 0:00:07[39m
[34m   loss: 0.13186748[39m



Epoch 1: Test Accuracy = 95.33%


[32mEpoch 2:  19%|███████▉                                  |  ETA: 0:00:00[39m
[A4m   loss: 0.07697785[39m
[32mEpoch 2:  40%|████████████████▊                         |  ETA: 0:00:00[39m
[A4m   loss: 0.13621089[39m
[32mEpoch 2:  62%|██████████████████████████                |  ETA: 0:00:00[39m
[A4m   loss: 0.124958545[39m
[32mEpoch 2:  84%|███████████████████████████████████▎      |  ETA: 0:00:00[39m
[A4m   loss: 0.0916301[39m
[32mEpoch 2: 100%|██████████████████████████████████████████| Time: 0:00:00[39m
[34m   loss: 0.11521212[39m



Epoch 2: Test Accuracy = 96.64%


[32mEpoch 3:  21%|████████▋                                 |  ETA: 0:00:00[39m
[A4m   loss: 0.07287196[39m
[32mEpoch 3:  42%|█████████████████▉                        |  ETA: 0:00:00[39m
[A4m   loss: 0.04878325[39m
[32mEpoch 3:  64%|██████████████████████████▉               |  ETA: 0:00:00[39m
[A4m   loss: 0.08716352[39m
[32mEpoch 3:  85%|███████████████████████████████████▉      |  ETA: 0:00:00[39m
[A4m   loss: 0.13799715[39m
[32mEpoch 3: 100%|██████████████████████████████████████████| Time: 0:00:00[39m
[34m   loss: 0.0926114[39m



Epoch 3: Test Accuracy = 97.18%


[32mEpoch 4:  22%|█████████▎                                |  ETA: 0:00:00[39m
[A4m   loss: 0.14395942[39m
[32mEpoch 4:  43%|██████████████████▏                       |  ETA: 0:00:00[39m
[A4m   loss: 0.054061525[39m
[32mEpoch 4:  64%|███████████████████████████               |  ETA: 0:00:00[39m
[A4m   loss: 0.08021752[39m
[32mEpoch 4:  85%|███████████████████████████████████▊      |  ETA: 0:00:00[39m
[A4m   loss: 0.055396523[39m
[32mEpoch 4: 100%|██████████████████████████████████████████| Time: 0:00:00[39m
[34m   loss: 0.044266265[39m



Epoch 4: Test Accuracy = 97.42%


[32mEpoch 5:  21%|████████▋                                 |  ETA: 0:00:00[39m
[A4m   loss: 0.03425047[39m
[32mEpoch 5:  41%|█████████████████▍                        |  ETA: 0:00:00[39m
[A4m   loss: 0.025607392[39m
[32mEpoch 5:  62%|██████████████████████████                |  ETA: 0:00:00[39m
[A4m   loss: 0.06354272[39m
[32mEpoch 5:  83%|██████████████████████████████████▉       |  ETA: 0:00:00[39m
[A4m   loss: 0.014235704[39m
[32mEpoch 5: 100%|██████████████████████████████████████████| Time: 0:00:00[39m
[34m   loss: 0.028210139[39m



Epoch 5: Test Accuracy = 97.69%


[32mEpoch 6:  22%|█████████▏                                |  ETA: 0:00:00[39m
[A4m   loss: 0.035971403[39m
[32mEpoch 6:  43%|██████████████████▏                       |  ETA: 0:00:00[39m
[A4m   loss: 0.03638174[39m
[32mEpoch 6:  65%|███████████████████████████▌              |  ETA: 0:00:00[39m
[A4m   loss: 0.037897542[39m
[32mEpoch 6:  87%|████████████████████████████████████▋     |  ETA: 0:00:00[39m
[A4m   loss: 0.037345223[39m
[32mEpoch 6: 100%|██████████████████████████████████████████| Time: 0:00:00[39m
[34m   loss: 0.057961807[39m



Epoch 6: Test Accuracy = 97.56%


[32mEpoch 7:  17%|███████                                   |  ETA: 0:00:01[39m
[A4m   loss: 0.03432457[39m
[32mEpoch 7:  38%|███████████████▉                          |  ETA: 0:00:00[39m
[A4m   loss: 0.014777028[39m
[32mEpoch 7:  59%|████████████████████████▊                 |  ETA: 0:00:00[39m
[A4m   loss: 0.019553546[39m
[32mEpoch 7:  80%|█████████████████████████████████▋        |  ETA: 0:00:00[39m
[A4m   loss: 0.042672932[39m
[32mEpoch 7: 100%|██████████████████████████████████████████| Time: 0:00:00[39m
[34m   loss: 0.026678463[39m



Epoch 7: Test Accuracy = 97.72%


[32mEpoch 8:  19%|████████▏                                 |  ETA: 0:00:00[39m
[A4m   loss: 0.0077943364[39m
[32mEpoch 8:  39%|████████████████▎                         |  ETA: 0:00:00[39m
[A4m   loss: 0.011481822[39m
[32mEpoch 8:  60%|█████████████████████████▍                |  ETA: 0:00:00[39m
[A4m   loss: 0.021726256[39m
[32mEpoch 8:  82%|██████████████████████████████████▎       |  ETA: 0:00:00[39m
[A4m   loss: 0.050892252[39m
[32mEpoch 8: 100%|██████████████████████████████████████████| Time: 0:00:00[39m
[34m   loss: 0.032553535[39m



Epoch 8: Test Accuracy = 97.54%


[32mEpoch 9:  20%|████████▌                                 |  ETA: 0:00:00[39m
[A4m   loss: 0.014968177[39m
[32mEpoch 9:  41%|█████████████████                         |  ETA: 0:00:00[39m
[A4m   loss: 0.017752368[39m
[32mEpoch 9:  62%|█████████████████████████▉                |  ETA: 0:00:00[39m
[A4m   loss: 0.014934659[39m
[32mEpoch 9:  83%|██████████████████████████████████▊       |  ETA: 0:00:00[39m
[A4m   loss: 0.009883834[39m
[32mEpoch 9: 100%|██████████████████████████████████████████| Time: 0:00:00[39m
[34m   loss: 0.016362146[39m



Epoch 9: Test Accuracy = 97.84%


[32mEpoch 10:  21%|████████▋                                |  ETA: 0:00:00[39m
[A4m   loss: 0.012205366[39m
[32mEpoch 10:  42%|█████████████████▎                       |  ETA: 0:00:00[39m
[A4m   loss: 0.0042203614[39m
[32mEpoch 10:  65%|██████████████████████████▋              |  ETA: 0:00:00[39m
[A4m   loss: 0.048551943[39m
[32mEpoch 10:  88%|███████████████████████████████████▉     |  ETA: 0:00:00[39m
[A4m   loss: 0.0048665786[39m
[32mEpoch 10: 100%|█████████████████████████████████████████| Time: 0:00:00[39m
[34m   loss: 0.033069585[39m



Epoch 10: Test Accuracy = 97.39%

Training complete!


In [15]:
# --- Final Evaluation ---
println("\nEvaluating final model on the test set...")
# Evaluate with the state switched to inference mode, so that any layer whose
# behaviour differs between training and inference is run in inference mode.
final_logits, _ = model(test_x, ps, Lux.testmode(st))
# Passing the label set makes `onecold` return the 0-based digit directly,
# instead of a 1-based index that has to be shifted by hand.
final_predictions = onecold(final_logits, 0:9)
final_accuracy = mean(final_predictions .== test_y_raw)

println("-------------------------------------------")
println("Final Test Accuracy: ", round(final_accuracy * 100, digits=2), "%")
println("-------------------------------------------")


Evaluating final model on the test set...
-------------------------------------------
Final Test Accuracy: 97.39%
-------------------------------------------
