---
layout: post
title: Catalogue of LR Schedulers
{{cover-img}}
---

In [34]:
from pathlib import Path
import os
from collections import defaultdict

from IPython.display import HTML, Image
import numpy as np
import matplotlib.pyplot as plt
from celluloid import Camera
import torch.optim as optim
import torch

# Output directory for the rendered GIFs, relative to the working directory.
ROOT = Path("./assets/img/")

# Path.mkdir with exist_ok=True avoids the check-then-create race of
# os.path.exists + os.makedirs; parents=True also creates "assets/" if missing.
ROOT.mkdir(parents=True, exist_ok=True)

In [12]:
def get_dummy_optimizer(lr):
    """Return a throwaway SGD optimizer used only to drive an LR scheduler.

    The optimizer holds a single scalar placeholder tensor. Nothing ever
    computes a gradient for it, so ``optimizer.step()`` leaves it untouched.
    Only the learning rate stored in its parameter group is of interest.

    Args:
        lr: Initial learning rate passed to ``optim.SGD``.

    Returns:
        An ``optim.SGD`` instance with one parameter group.
    """
    placeholder_params = [torch.tensor(0)]
    return optim.SGD(placeholder_params, lr=lr)

# StepLR

In [32]:
def step_lr(step_size=50, gamma=0.1, n_epochs=200, fps=30):
    """Animate PyTorch's StepLR schedule and save it as ``steplr.gif``.

    A dummy optimizer starts at lr=1.0. ``StepLR`` multiplies its learning
    rate by ``gamma`` every ``step_size`` epochs. The animation draws the
    schedule on a linear axis (left) and as ``log10(lr)`` (right), with one
    frame per epoch.

    Args:
        step_size: Epochs between decays (passed to ``StepLR``).
        gamma: Multiplicative decay factor (passed to ``StepLR``). Must be
            positive for the ``log10`` panel to stay finite.
        n_epochs: Number of epochs to simulate, i.e. number of frames.
        fps: Frame rate of the saved GIF.

    Returns:
        ``IPython.display.Image`` referencing the saved GIF by URL.
    """
    optimizer = get_dummy_optimizer(lr=1.)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

    epochs = np.arange(0, n_epochs)

    lrs = []
    for _ in epochs:
        # get_last_lr() returns one value per param group. The dummy optimizer
        # has a single group, so keep a flat 1-D series.
        lrs.append(scheduler.get_last_lr()[0])

        # PyTorch expects optimizer.step() to be called before scheduler.step().
        optimizer.step()
        scheduler.step()

    lr_arr = np.array(lrs)
    log_lr_arr = np.log10(lr_arr)

    fig, (ax, ax_log) = plt.subplots(1, 2, figsize=(12, 5))

    fig.suptitle("StepLR")

    ax.set_xlabel("epoch")
    ax.set_ylabel("lr")

    ax_log.set_xlabel("epoch")
    ax_log.set_ylabel("log(lr)")

    # Each frame redraws the curve up to the current epoch plus a marker at
    # the current point; camera.snap() records the frame.
    camera = Camera(fig)
    for epoch in epochs:
        ax.scatter(epoch, lr_arr[epoch], color="royalblue")
        ax.plot(epochs[:epoch + 1], lr_arr[:epoch + 1], color="royalblue")

        ax_log.scatter(epoch, log_lr_arr[epoch], color="royalblue")
        ax_log.plot(epochs[:epoch + 1], log_lr_arr[:epoch + 1], color="royalblue")

        camera.snap()

    anim = camera.animate()
    # Close this figure explicitly rather than whatever pyplot considers
    # "current". In a notebook this keeps the static figure from being shown
    # alongside the GIF.
    plt.close(fig)

    gif_path = ROOT / "steplr.gif"
    anim.save(gif_path, writer="pillow", fps=fps)
    # IPython's Image expects the URL as a string, not a Path.
    return Image(url=str(gif_path))


step_lr()

# ExponentialLR

In [45]:
def exp_lr(gammas=(0.9, 0.95, 0.99), n_epochs=200, fps=30):
    """Animate PyTorch's ExponentialLR for several ``gamma`` values; save ``explr.gif``.

    Each ``gamma`` gets its own dummy optimizer starting at lr=1.0, which
    ``ExponentialLR`` decays by ``gamma`` once per epoch. All curves share
    one axis, with one frame per epoch and a legend identifying each
    ``gamma``.

    Args:
        gammas: Decay factors to compare, one curve per value. Colours are
            taken in order from a fixed palette and cycled if there are more
            gammas than colours, so any values are accepted.
        n_epochs: Number of epochs to simulate, i.e. number of frames.
        fps: Frame rate of the saved GIF.

    Returns:
        ``IPython.display.Image`` referencing the saved GIF by URL.
    """
    epochs = np.arange(0, n_epochs)

    # Record each schedule once as a flat 1-D array. get_last_lr() returns one
    # value per param group, and the dummy optimizer has a single group.
    lr_curves = {}
    for gamma in gammas:
        optimizer = get_dummy_optimizer(lr=1.)
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)

        lrs = []
        for _ in epochs:
            lrs.append(scheduler.get_last_lr()[0])

            # PyTorch expects optimizer.step() to be called before scheduler.step().
            optimizer.step()
            scheduler.step()

        lr_curves[gamma] = np.array(lrs)

    fig, ax = plt.subplots(1, 1)

    fig.suptitle("ExponentialLR")

    ax.set_xlabel("epoch")
    ax.set_ylabel("lr")

    # Derive colours from the gammas actually plotted. The previous hard-coded
    # {gamma: colour} table raised KeyError for any gamma not listed in it.
    palette = ("royalblue", "orangered", "forestgreen")
    colors = {gamma: palette[i % len(palette)] for i, gamma in enumerate(lr_curves)}

    camera = Camera(fig)
    for epoch in epochs:

        lines = []
        for gamma, lr_arr in lr_curves.items():
            ax.scatter(epoch, lr_arr[epoch], color=colors[gamma])
            line, = ax.plot(epochs[:epoch + 1], lr_arr[:epoch + 1],
                            color=colors[gamma], label=rf"$\gamma={gamma}$")
            lines.append(line)

        # Lines from earlier frames remain on the axes, so pass only this
        # frame's handles to avoid repeated legend entries.
        ax.legend(handles=lines)

        camera.snap()

    anim = camera.animate()
    # Close this figure explicitly rather than whatever pyplot considers
    # "current". In a notebook this keeps the static figure from being shown
    # alongside the GIF.
    plt.close(fig)

    gif_path = ROOT / "explr.gif"
    anim.save(gif_path, writer="pillow", fps=fps)
    # IPython's Image expects the URL as a string, not a Path.
    return Image(url=str(gif_path))


exp_lr()