In [None]:
#r "nuget: TorchSharp-cpu"

open TorchSharp
open type TorchSharp.torch
open type TorchSharp.TensorExtensionMethods
open type TorchSharp.torch.distributions

open Microsoft.DotNet.Interactive.Formatting
Formatter.SetPreferredMimeTypeFor(typeof<torch.Tensor>, "text/plain")
Formatter.Register<torch.Tensor>(fun (x:torch.Tensor) -> x.ToString(true))

# Training with a Learning Rate Scheduler

In Tutorial 6, we saw that the optimizers take an argument called the 'learning rate', but we didn't spend much time on it except to say that it can have a great impact on how quickly training converges toward a solution. In fact, you can choose the learning rate (LR) so poorly that training doesn't converge at all.

If the LR is too small, training will go very slowly, wasting compute resources. If it is too large, training can diverge and end in numeric overflow or NaNs. Either way, you're in trouble.

To further complicate matters, it turns out that the learning rate shouldn't necessarily be constant. Training can go much better if the learning rate starts out relatively large and gets smaller as you get closer to the end.

There's a solution for this, called a Learning Rate Scheduler. An LRS instance has access to the internal state of the optimizer, and can modify the LR as it goes along. There are several algorithms for scheduling, but TorchSharp only implements the two most conceptually simple: StepLR and ExponentialLR.
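To make the two schedules concrete, here is a sketch of the decay rules, assuming TorchSharp follows the PyTorch definitions of the same names. The symbols are introduced here for illustration: $\eta_0$ is the initial LR, $\gamma$ the decay factor, $s$ the step size, and $t$ the number of times the scheduler has stepped.

$$
\eta_t = \eta_0 \cdot \gamma^{\lfloor t / s \rfloor} \quad \text{(StepLR)}
\qquad\qquad
\eta_t = \eta_0 \cdot \gamma^{t} \quad \text{(ExponentialLR)}
$$

Under these assumptions, ExponentialLR behaves like StepLR with a step size of 1: both shrink the LR by the same factor on every step.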

Before demonstrating, let's have a model and a baseline training loop.

In [None]:
type Trivial() as this = 
    inherit nn.Module("Trivial")

    let lin1 = nn.Linear(1000L, 100L)
    let lin2 = nn.Linear(100L, 10L)

    do
        this.RegisterComponents()

    override _.forward(input) = 
    
        use x = lin1.forward(input)
        use y = nn.functional.relu(x)
        lin2.forward(y)

In [None]:
let learning_rate = 0.01
let model = Trivial()

let dataBatch = rand(32,1000)  // Our pretend input data
let resultBatch = rand(32,10)  // Our pretend ground truth.

let loss x y = nn.functional.mse_loss().Invoke(x,y)

let optimizer = torch.optim.SGD(model.parameters(), learning_rate)

for epoch = 1 to 100 do
    // Compute the loss
    let pred = model.forward(dataBatch)
    let output = loss pred resultBatch

    // Clear the gradients before doing the back-propagation
    model.zero_grad()

    // Do back-propagation, which computes all the gradients.
    output.backward()

    optimizer.step() |> ignore

let pred = model.forward(dataBatch)
(loss pred resultBatch).item<single>()

When I ran this, the loss was down to 0.068 after 3 seconds. (It took longer the first time around.)

## StepLR

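StepLR uses multiplication, not subtraction, to adjust the learning rate every so often: after every `step_size` epochs, it multiplies the LR by a constant factor (`gamma`). In the cell below, the step size is 1 and the factor is 0.99. The difference it makes to the training loop is that you wrap the optimizer, and then call `step` on the scheduler instead of the optimizer.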

In [None]:
let learning_rate = 0.01
let model = Trivial()

let dataBatch = rand(32,1000)  // Our pretend input data
let resultBatch = rand(32,10)  // Our pretend ground truth.

let loss x y = nn.functional.mse_loss().Invoke(x,y)

let optimizer = torch.optim.SGD(model.parameters(), learning_rate)
let scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1u, 0.99)    // Pass in 'verbose=true' if you want to see how the LR changes over time.

for epoch = 1 to 100 do
    // Compute the loss
    let pred = model.forward(dataBatch)
    let output = loss pred resultBatch

    // Clear the gradients before doing the back-propagation
    model.zero_grad()

    // Do back-propagation, which computes all the gradients.
    output.backward()

    scheduler.step() |> ignore

let pred = model.forward(dataBatch)
(loss pred resultBatch).item<single>()

Well, that was underwhelming. The loss (in my case) went up just a little bit, essentially a rounding error. For this trivial model, using a scheduler isn't going to make much of a difference, and it may not make much of a difference even for complex models. It's very hard to know until you try it, but now you know how to try it out.
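To see what the schedule did to the LR during those 100 epochs, here is a small sketch in plain F# (no TorchSharp required). It assumes StepLR multiplies the LR by `gamma` once every `stepSize` epochs, as the PyTorch scheduler of the same name does. The `stepLR` helper and its parameter names are made up for this illustration and are not part of the TorchSharp API.

```fsharp
// Hypothetical helper, not a TorchSharp function: the LR after `epoch`
// scheduler steps, assuming the LR is multiplied by `gamma` once every
// `stepSize` epochs.
let stepLR (initialLR: float) (stepSize: int) (gamma: float) (epoch: int) =
    // Integer division counts how many decays have happened so far.
    initialLR * (gamma ** float (epoch / stepSize))

// The settings from the StepLR cell: LR 0.01, step size 1, gamma 0.99.
let schedule = [ for epoch in 0 .. 100 -> stepLR 0.01 1 0.99 epoch ]

printfn "epoch   0: %.5f" schedule.[0]
printfn "epoch  50: %.5f" schedule.[50]
printfn "epoch 100: %.5f" schedule.[100]
```

With these settings, the LR ends the run at a little over a third of its starting value. A gentle decay like that, over so few epochs, is consistent with the loss barely changing.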

In the future, TorchSharp will add more of the LR schedulers that are available for PyTorch, as well as allow them to be combined.