Regularization is a common technique in machine learning used to improve model performance and generalization. It involves adding a penalty term to the loss function, typically to prevent overfitting. Overfitting occurs when a model captures patterns specific to the training data but fails to generalize well to unseen data. Regularization can also help control exploding weights or enforce sparsity, which can lead to more interpretable models. In this notebook, we will demonstrate how to incorporate regularization into model training using `Lux.jl`.

`Lux.jl` is a machine learning library written in pure Julia. It is designed for simplicity, flexibility, and high performance. One of its defining features is the strict separation between model parameters and layer structures. This approach differs from many other machine learning libraries, but it provides deeper insight into the inner workings of model training and architecture.

# Setup

Let us first install `Lux.jl` and other required packages ...

In [None]:
import Pkg
Pkg.activate(".")
Pkg.add(["Lux", "Random", "Printf", "Enzyme", "Optimisers"])

... and load them into the notebook.

In [2]:
using Lux, Random, Printf, Enzyme, Optimisers

Besides `Lux`, we need:
- `Random` for generating random numbers
- `Printf` for formatted printing
- `Enzyme` for automatic differentiation of the models and loss functions
- `Optimisers` for the Adam optimiser and the `WeightDecay` rule

Let's define a simple multi-layer perceptron (MLP) model with a single hidden layer and a $\tanh$ activation function.

In [3]:
model = Chain(
    Dense(2, 4, tanh),
    Dense(4, 1),
)

Chain(
    layer_1 = Dense(2 => 4, tanh),      # 12 parameters
    layer_2 = Dense(4 => 1),            # 5 parameters
)         # Total: 17 parameters,
          #        plus 0 states.

Besides the separation of model parameters and layer structure, `Lux` takes randomness very seriously, too. To align with this design philosophy, we make this notebook reproducible by fixing the seed of the random number generator and generating dummy training data.

In [4]:
rng = Random.default_rng()
Random.seed!(rng, 42)

X_train = randn(rng, Float32, 2,100)
y_train = randn(rng, Float32, 1,100)

1×100 Matrix{Float32}:
 -1.48774  -1.55746  -0.013772  0.423002  …  -0.749528  -0.293559  -1.15156
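
To see what fixing the seed buys us, here is a minimal sketch that uses only the `Random` standard library. It assumes an explicit `Xoshiro` generator (the notebook itself seeds the default RNG instead), and the variable names are made up for illustration. It checks that two generators with the same seed produce identical dummy data.

```julia
using Random

# Two independent generators seeded identically (seed 42, as in the notebook).
rng_a = Xoshiro(42)
rng_b = Xoshiro(42)

# Same call sequence on both generators.
X_a = randn(rng_a, Float32, 2, 100)
X_b = randn(rng_b, Float32, 2, 100)

# Identical seeds and identical call order give identical draws.
same_data = X_a == X_b

# A different seed gives different data.
X_c = randn(Xoshiro(43), Float32, 2, 100)
different_data = X_a != X_c
```

The order of the calls matters as well: drawing `y_train` before `X_train` from the same generator would yield different arrays.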

As mentioned, `Lux` keeps model parameters and layer structures separate. The parameters are set up using a `setup` method, which returns the parameters `ps` and the layer states `st`. In our simple MLP example, the layers don’t have any states. But for layers like batch normalization, you’d see states being used.

In [5]:
ps, st = LuxCore.setup(rng, model)

((layer_1 = (weight = Float32[0.353489 -1.8094461; 1.5090263 -1.1654156; -0.83640474 -1.1235346; 1.8991852 1.449335], bias = Float32[0.41897145, -0.069400184, -0.6758513, 0.67588615]), layer_2 = (weight = Float32[0.44184172 0.014959536 0.5158945 0.3650591], bias = Float32[0.24236047])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()))

To evaluate the model, we simply call it like a regular function, passing the model parameters (and the empty states) as function arguments.

In [6]:
model(X_train[:,1], ps, st)

(Float32[0.7788858], (layer_1 = NamedTuple(), layer_2 = NamedTuple()))
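
Because the parameters are plain data, the forward pass can also be spelled out by hand. The following is a sketch in plain Julia, not Lux's actual implementation: the helper `mlp_forward` and the randomly drawn parameter values are hypothetical, and only the NamedTuple layout mirrors the `ps` shown above. It assumes that each dense layer computes `activation.(weight * x .+ bias)`.

```julia
using Random

demo_rng = Xoshiro(0)

# Hypothetical parameters with the same layout as `ps` above:
# one NamedTuple per layer, each holding a weight matrix and a bias vector.
params = (
    layer_1 = (weight = randn(demo_rng, Float32, 4, 2), bias = zeros(Float32, 4)),
    layer_2 = (weight = randn(demo_rng, Float32, 1, 4), bias = zeros(Float32, 1)),
)

# Hand-written forward pass of the 2 => 4 => 1 MLP (illustrative helper).
function mlp_forward(x, p)
    h = tanh.(p.layer_1.weight * x .+ p.layer_1.bias)  # hidden layer with tanh
    return p.layer_2.weight * h .+ p.layer_2.bias      # linear output layer
end

x_demo = randn(demo_rng, Float32, 2)
y_demo = mlp_forward(x_demo, params)

# Parameter count: (4*2 + 4) + (1*4 + 1) = 12 + 5 = 17
n_params = sum(length(l.weight) + length(l.bias) for l in params)
```

The count of 17 matches the summary printed when the `Chain` was constructed.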

# Regularization

There are multiple ways to regularize a model. Common techniques include $l_1$ and $l_2$ regularization. In this notebook, we focus on $l_2$ regularization, which is also commonly called weight decay. It adds a penalty term to the loss function that is proportional to the sum of the squared model parameters (weights).
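
Written out, with $\theta$ denoting all model parameters and $\lambda > 0$ the regularization strength (both symbols are notation introduced here for illustration), the regularized objective and its gradient read:

$$
\mathcal{L}_{\text{total}}(\theta) = \mathcal{L}_{\text{MSE}}(\theta) + \lambda \sum_i \theta_i^2,
\qquad
\nabla_\theta \mathcal{L}_{\text{total}} = \nabla_\theta \mathcal{L}_{\text{MSE}} + 2\lambda\,\theta .
$$

The gradient picks up a term proportional to the parameters themselves, so every gradient step shrinks each weight towards zero. That is where the name weight decay comes from.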

In my first attempt to implement $l_2$ regularization of the model parameters, I wrote the following custom loss function:

```{julia}
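function loss_function(model, ps, st, (x, y))
    # Element type of the parameters (Float32 here), so the penalty keeps that type
    T = eltype(Base.Flatten(first(ps)))

    # Data term: mean squared error of the model predictions
    loss_mse = MSELoss()(model, ps, st, (x,y))[1]

    # Penalty term: sum of squares over all weights and biases of every layer
    loss_reg = zero(T)
    for p in ps
        loss_reg += sum(abs2, Base.Flatten(p))
    end

    # Total loss with a regularization strength of 0.001
    loss_total = loss_mse + convert(T, 0.001) * loss_reg

    # Return the loss, the unchanged states, and an empty statistics NamedTuple
    return loss_total, st, NamedTuple()
end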
```

It works, but it is not the most extensible solution. The Julia community on [Discourse](https://discourse.julialang.org/t/custom-loss-functions-in-lux-jl/125661) suggested using the `WeightDecay` rule from `Optimisers.jl` instead, which does what we need without a custom loss. By chaining it with the `Adam` optimiser via `OptimiserChain` and keeping the plain `MSELoss` as the loss function, we can train our dummy MLP model with $l_2$ regularization like this:

In [7]:
# code adapted from the Lux documentation https://lux.csail.mit.edu/stable/
function train_model!(model, ps, st, x, y)

    train_state = Lux.Training.TrainState(model, ps, st,
        # here we chain together the optimiser Adam with 
        # a WeightDecay of 0.001. 
        OptimiserChain(Adam(0.01f0), WeightDecay(0.001)),
    )

    for iter in 1:1000
        _, loss, _, train_state = Lux.Training.single_train_step!(
            AutoEnzyme(),
            MSELoss(),
            (x, y), train_state
        )
        if iter % 100 == 1 || iter == 1000
            @printf "Iteration: %04d \t Loss: %10.9g\n" iter loss
        end
    end

    return model, ps, st
end

train_model!(model, ps, st, X_train, y_train)

Iteration: 0001 	 Loss: 1.24611223
Iteration: 0101 	 Loss: 0.991424084
Iteration: 0201 	 Loss: 0.959512353
Iteration: 0301 	 Loss: 0.949317396
Iteration: 0401 	 Loss: 0.941184938
Iteration: 0501 	 Loss: 0.932920218
Iteration: 0601 	 Loss: 0.922357559
Iteration: 0701 	 Loss: 0.910138071
Iteration: 0801 	 Loss: 0.898681343
Iteration: 0901 	 Loss: 0.888456821
Iteration: 1000 	 Loss: 0.879557967


(Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(2 => 4, tanh), layer_2 = Dense(4 => 1)), nothing), (layer_1 = (weight = Float32[5.620661 -4.1725636; 3.3105361 -3.3534684; -1.4295509 -0.7033358; 3.2561698 1.6087654], bias = Float32[0.7682078, 0.48789456, -0.012730246, -0.3011559]), layer_2 = (weight = Float32[1.2095982 -1.1726627 1.2155563 0.9848269], bias = Float32[0.29594985])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()))

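That's it! Just chain the `Adam` optimiser with `WeightDecay` via `OptimiserChain`, and the model weights are regularized during training. One caveat, as far as I understand the `Optimisers.jl` documentation: the order inside `OptimiserChain` matters. `WeightDecay` corresponds exactly to an $l_2$ penalty in the loss only when it comes first in the chain. Placed after `Adam`, as above, it acts as decoupled weight decay in the style of AdamW.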
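
As a sanity check on why adding a multiple of the parameters to the gradient acts like an $l_2$ penalty, here is a small sketch in plain Julia with no `Lux` or `Optimisers` calls. The names `penalty` and `numeric_grad` and the values of `λ` and `η` are illustrative assumptions. It compares the analytic gradient `λ .* w` of the penalty `λ/2 * sum(abs2, w)` with a central finite-difference estimate.

```julia
# Illustrative penalty strength and parameter vector
# (Float64 keeps the finite differences accurate).
λ = 0.001
w = [0.5, -1.25, 2.0, 0.0]

# l2 penalty in the "λ/2" convention, whose gradient is exactly λ .* w.
penalty(v) = λ / 2 * sum(abs2, v)

# Central finite-difference gradient (hypothetical helper, for checking only).
function numeric_grad(f, v; h = 1e-6)
    g = similar(v)
    for i in eachindex(v)
        e = zeros(length(v))
        e[i] = h
        g[i] = (f(v .+ e) - f(v .- e)) / (2h)
    end
    return g
end

analytic = λ .* w
numeric = numeric_grad(penalty, w)
max_err = maximum(abs.(analytic .- numeric))

# One plain gradient-descent step on the penalty alone shrinks every weight
# by the factor (1 - η * λ).
η = 0.1
w_next = w .- η .* analytic
```

With the `λ/2` convention, a penalty strength of `λ` in the gradient corresponds to a factor of `λ/2` in front of the sum of squares, so the custom loss above with its factor `0.001` would match a gradient term of `0.002 .* w`.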