In [1]:
# Simple VAE (Lux) with training loop in user‑requested style
# --------------------------------------------------------------------
# Implements: `train!(model, ps, st, data; epochs=10)`
#   • Initializes `tstate = Training.TrainState(model, ps, st, Adam(...))`
#   • Uses `single_train_step!` inside the nested loops
#   • Collects `losses` (for visualization) and returns them.
# Dataset: MNIST (flattened 28×28), reconstruction target y = x.

using Lux, Lux.Training, Random, Optimisers, Zygote, MLDatasets, Functors

In [2]:
# ------------------------------
# Hyper‑parameters
# ------------------------------
latent_dim = 10
hidden_dim = 128
input_dim  = 28 * 28
batch_size = 64
epochs     = 5
learning_rate = 1e-2

0.01

In [3]:
# ------------------------------
# MNIST loader that yields (x, y = x) tuples
# ------------------------------
ENV["DATADEPS_ALWAYS_ACCEPT"] = "true"
train_x, _ = MLDatasets.MNIST.traindata()
train_x = reshape(Float32.(train_x) ./ 255f0, input_dim, :)

└ @ MLDatasets /Users/briandepasquale/.julia/packages/MLDatasets/0MkOE/src/datasets/vision/mnist.jl:187


784×60000 Matrix{Float32}:
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 ⋮                        ⋮              ⋱            ⋮                   
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0 

In [4]:
function batch_iter(X; batch=batch_size, rng=Random.default_rng())
    idx = collect(1:size(X, 2))
    Random.shuffle!(rng, idx)
    for chunk in Iterators.partition(idx, batch)
        xb = X[:, chunk]
        yield((xb, xb))            # (input, target) pair
    end
end

batch_iter (generic function with 1 method)

In [None]:
function encoder(
    rng=Random.default_rng();
    num_latent_dims::Int,
    image_shape::Dims{3},
    max_num_filters::Int,
)
    flattened_dim = prod(image_shape[1:2] .÷ 8) * max_num_filters
    return @compact(;
        embed=Dense(flattened_dim, flattened_dim; init_bias=zeros32),
        proj_mu=Dense(flattened_dim, num_latent_dims; init_bias=zeros32),
        proj_log_var=Dense(flattened_dim, num_latent_dims; init_bias=zeros32),
        rng
    ) do x
        y = embed(x)

        μ = proj_mu(y)
        logσ² = proj_log_var(y)

        T = eltype(logσ²)
        logσ² = clamp.(logσ², -T(20.0f0), T(10.0f0))
        σ = exp.(logσ² .* T(0.5))

        # Generate a tensor of random values from a normal distribution
        rng = Lux.replicate(rng)
        ϵ = randn_like(rng, σ)

        # Reparameterization trick to brackpropagate through sampling
        z = ϵ .* σ .+ μ

        @return z, μ, logσ²
    end
end

In [9]:
# Initialize parameters & state
rng = Random.default_rng()
ps, st = Lux.setup(rng, vae)

(SimpleVAE{@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, @NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, @NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, @NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, @NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}((weight = Float32[-0.12327705 -0.11697393 … -0.05656602 0.040849246; 0.015953189 0.08447786 … -0.016243627 -0.024371725; … ; 0.08097171 -0.029226342 … -6.241491f-5 -0.09555681; 0.08673438 -0.005603862 … -0.065215245 -0.043865424], bias = Float32[0.029358562, 0.019185463, -0.01677743, 0.019504147, -0.019974573, 0.0143271815, 0.015715718, 0.004030245, 0.027582126, 0.016351137  …  0.028981255, 0.03434905, 0.028943207, -0.031130774, 0.007312792, 0.026094684, 0.025853554, 0.009912525, -0.02758334, 0.025112638]), (weight = Float32[0.108728476 -0.056329127 … 0.0057340604 0.14007536; 0.14121042 0.14495571 … 0.12086126 0.13093837; … ; -0.047098126 0.017961422 … 0.14909995 0.08167165; 

In [20]:
# ------------------------------
# Loss helper (MSE + KL)
# ------------------------------
const mse_loss = MSELoss()

function full_loss(model, p, s, x, y, rng)
    (x̂, kl), _ = model(x, rng, p, s)
    return mse_loss(x̂, y) + kl
end

full_loss (generic function with 1 method)

In [21]:
opt = Adam(0.03f0)

Adam(eta=0.03, beta=(0.9, 0.999), epsilon=1.0e-8)

In [23]:
tstate = Training.TrainState(vae, ps, st, opt)

MethodError: MethodError: no method matching Lux.Training.TrainState(::SimpleVAE{Dense{typeof(relu), Int64, Int64, Nothing, Nothing, Static.True}, Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}, Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}, Dense{typeof(relu), Int64, Int64, Nothing, Nothing, Static.True}, Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, ::SimpleVAE{@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, @NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, @NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, @NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, @NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}, ::SimpleVAE{@NamedTuple{}, @NamedTuple{}, @NamedTuple{}, @NamedTuple{}, @NamedTuple{}}, ::Adam{Float32, Tuple{Float64, Float64}, Float64})

Closest candidates are:
  Lux.Training.TrainState(::__T_cache, ::__T_objective_function, ::__T_model, ::__T_parameters, !Matched::__T_states, !Matched::__T_optimizer, !Matched::__T_optimizer_state, !Matched::Int64) where {__T_cache, __T_objective_function, __T_model, __T_parameters, __T_states, __T_optimizer, __T_optimizer_state}
   @ Lux ~/.julia/packages/ConcreteStructs/7Lv7u/src/ConcreteStructs.jl:142
  Lux.Training.TrainState(!Matched::AbstractLuxLayer, ::Any, ::Any, ::AbstractRule)
   @ Lux ~/.julia/packages/Lux/L2VO7/src/helpers/training.jl:65


In [17]:
# ------------------------------
# User‑style training loop
# ------------------------------
function train!(model, ps, st, data; epochs=10, lr=1e-2, rng=Random.default_rng())
    losses = []
    tstate = Lux.Training.TrainState((model, ps, st), Optimisers.Adam(lr))
    for _ in 1:epochs
        for (x, y) in data()
            _, loss, _, tstate = Training.single_train_step!(AutoZygote(), loss_closure, (x,y), tstate)
            push!(losses, loss)
        end
    end
    return losses, tstate.parameters, tstate.states
end

train! (generic function with 1 method)

In [18]:
# ------------------------------
# Run training
losses, ps, st = train!(vae, ps, st, data; epochs=epochs, lr=learning_rate, rng=rng)
@info "Done. Final batch loss ≈ $(round(losses[end], sigdigits=5))."


MethodError: MethodError: no method matching Lux.Training.TrainState(::Tuple{SimpleVAE{Dense{typeof(relu), Int64, Int64, Nothing, Nothing, Static.True}, Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}, Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}, Dense{typeof(relu), Int64, Int64, Nothing, Nothing, Static.True}, Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, SimpleVAE{@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, @NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, @NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, @NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, @NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}, SimpleVAE{@NamedTuple{}, @NamedTuple{}, @NamedTuple{}, @NamedTuple{}, @NamedTuple{}}}, ::Adam{Float64, Tuple{Float64, Float64}, Float64})

Closest candidates are:
  Lux.Training.TrainState(::__T_cache, ::__T_objective_function, !Matched::__T_model, !Matched::__T_parameters, !Matched::__T_states, !Matched::__T_optimizer, !Matched::__T_optimizer_state, !Matched::Int64) where {__T_cache, __T_objective_function, __T_model, __T_parameters, __T_states, __T_optimizer, __T_optimizer_state}
   @ Lux ~/.julia/packages/ConcreteStructs/7Lv7u/src/ConcreteStructs.jl:142
  Lux.Training.TrainState(!Matched::AbstractLuxLayer, ::Any, !Matched::Any, !Matched::AbstractRule)
   @ Lux ~/.julia/packages/Lux/L2VO7/src/helpers/training.jl:65
